Skip to content
This repository has been archived by the owner on Dec 12, 2024. It is now read-only.

Commit

Permalink
Add conversions for tosa.slice and tosa.cast (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
OliverScherf authored Apr 20, 2021
1 parent d49e68f commit 4de3eec
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 75 deletions.
2 changes: 2 additions & 0 deletions docs/tosa-op-coverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The table below shows the supported TOSA ops.
| const | :heavy_check_mark: | |
| **Unary elementwise ops**
| abs | :heavy_check_mark: | |
| cast | :heavy_check_mark: | |
| ceil | :heavy_check_mark: | |
| clamp | :heavy_check_mark: | |
| exp | :heavy_check_mark: | |
Expand Down Expand Up @@ -37,5 +38,6 @@ The table below shows the supported TOSA ops.
| reduce_prod | :heavy_check_mark: | |
| reduce_sum | :heavy_check_mark: | |
| reshape | :heavy_check_mark: | |
| slice | :white_check_mark: | Only for 1D to 4D inputs |
| pad | :white_check_mark: | Quantization not supported |
| transpose | :heavy_check_mark: | |
83 changes: 83 additions & 0 deletions include/emitc/emitc_core_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ inline Src ceil(Src x) {
return unary<Src>(x, f);
}

// ConvertOp
template <typename Dest, typename Src>
inline Dest convert(Src x) {
using ET_Dest = typename get_element_type<Dest>::type;
using ET_Src = typename get_element_type<Src>::type;

auto cast = [](ET_Src value) { return static_cast<ET_Dest>(value); };

return unary<Dest, Src, UnaryFuncType<ET_Dest, ET_Src>>(x, cast);
}

// ExpOp
template <typename Src>
inline Src exp(Src x) {
Expand Down Expand Up @@ -314,6 +325,78 @@ inline Dest reshape(Src x) {
return z;
}

// SliceOp
// Overload for 1d case.
template <typename Dest, typename Src, IsTensorOfDim<1, Src> = true>
Dest slice(Src x, Tensor<int64_t, 1> start_indices,
Tensor<int64_t, 1> limit_indices, Tensor<int64_t, 1> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
z[index++] = x(i);
}

return z;
}

// Overload for 2d case.
template <typename Dest, typename Src, IsTensorOfDim<2, Src> = true>
Dest slice(Src x, Tensor<int64_t, 2> start_indices,
Tensor<int64_t, 2> limit_indices, Tensor<int64_t, 2> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
z[index++] = x(i, j);
}
}

return z;
}

// Overload for 3d case.
template <typename Dest, typename Src, IsTensorOfDim<3, Src> = true>
Dest slice(Src x, Tensor<int64_t, 3> start_indices,
Tensor<int64_t, 3> limit_indices, Tensor<int64_t, 3> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
for (int64_t k = start_indices[2]; k < limit_indices[2];
k += strides[2]) {
z[index++] = x(i, j, k);
}
}
}

return z;
}

// Overload for 4d case.
template <typename Dest, typename Src, IsTensorOfDim<4, Src> = true>
Dest slice(Src x, Tensor<int64_t, 4> start_indices,
Tensor<int64_t, 4> limit_indices, Tensor<int64_t, 4> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
for (int64_t k = start_indices[2]; k < limit_indices[2];
k += strides[2]) {
for (int64_t c = start_indices[3]; c < limit_indices[3];
c += strides[3]) {
z[index++] = x(i, j, k, c);
}
}
}
}

return z;
}

// PadOp
// TODO: Add support for negative edge padding
template <typename Dest, typename Src>
Expand Down
81 changes: 6 additions & 75 deletions include/emitc/emitc_mhlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,7 @@ typename replace_element_type<bool, Src>::type compare(Src x, Src y) {
// ConvertOp
template <typename Dest, typename Src>
inline Dest convert(Src x) {
using ET_Dest = typename get_element_type<Dest>::type;
using ET_Src = typename get_element_type<Src>::type;

auto cast = [](ET_Src value) { return static_cast<ET_Dest>(value); };

return unary<Dest, Src, UnaryFuncType<ET_Dest, ET_Src>>(x, cast);
return emitc::convert<Dest>(x);
}

// CosOp
Expand Down Expand Up @@ -356,75 +351,11 @@ inline Dest concatenate(Src1 input1, Src... inputs) {
}

// SliceOp
// Overload for 1d case.
template <typename Dest, typename Src, IsTensorOfDim<1, Src> = true>
Dest slice(Src x, Tensor<int64_t, 1> start_indices,
Tensor<int64_t, 1> limit_indices, Tensor<int64_t, 1> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
z[index++] = x(i);
}

return z;
}

// Overload for 2d case.
template <typename Dest, typename Src, IsTensorOfDim<2, Src> = true>
Dest slice(Src x, Tensor<int64_t, 2> start_indices,
Tensor<int64_t, 2> limit_indices, Tensor<int64_t, 2> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
z[index++] = x(i, j);
}
}

return z;
}

// Overload for 3d case.
template <typename Dest, typename Src, IsTensorOfDim<3, Src> = true>
Dest slice(Src x, Tensor<int64_t, 3> start_indices,
Tensor<int64_t, 3> limit_indices, Tensor<int64_t, 3> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
for (int64_t k = start_indices[2]; k < limit_indices[2];
k += strides[2]) {
z[index++] = x(i, j, k);
}
}
}

return z;
}

// Overload for 4d case.
template <typename Dest, typename Src, IsTensorOfDim<4, Src> = true>
Dest slice(Src x, Tensor<int64_t, 4> start_indices,
Tensor<int64_t, 4> limit_indices, Tensor<int64_t, 4> strides) {
Dest z;

size_t index = 0;
for (int64_t i = start_indices[0]; i < limit_indices[0]; i += strides[0]) {
for (int64_t j = start_indices[1]; j < limit_indices[1]; j += strides[1]) {
for (int64_t k = start_indices[2]; k < limit_indices[2];
k += strides[2]) {
for (int64_t c = start_indices[3]; c < limit_indices[3];
c += strides[3]) {
z[index++] = x(i, j, k, c);
}
}
}
}

return z;
template <typename Dest, typename Src>
Dest slice(Src x, Tensor<int64_t, Src::rank()> start_indices,
Tensor<int64_t, Src::rank()> limit_indices,
Tensor<int64_t, Src::rank()> strides) {
return emitc::slice<Dest, Src>(x, start_indices, limit_indices, strides);
}

// DynamicSliceOp
Expand Down
18 changes: 18 additions & 0 deletions include/emitc/emitc_tosa.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <limits>

#include "emitc_core_ops.h"
#include "emitc_std.h"

namespace tosa {

Expand All @@ -28,6 +29,12 @@ inline Src abs(Src x) {
return emitc::abs<Src>(x);
}

// CastOp
template <typename Dest, typename Src>
inline Dest cast(Src x) {
return emitc::convert<Dest>(x);
}

// CeilOp
template <typename Src>
inline Src ceil(Src x) {
Expand Down Expand Up @@ -481,6 +488,17 @@ inline Dest reshape(Src x) {
return emitc::reshape<Dest>(x);
}

// SliceOp
template <typename Dest, typename Src>
Dest slice(Src x, Tensor<int64_t, Src::rank()> start_indices,
Tensor<int64_t, Src::rank()> slice_sizes) {
Tensor<int64_t, Src::rank()> limit_indices =
emitc::add(start_indices, slice_sizes);
Tensor<int64_t, Src::rank()> strides =
standard::splat<Tensor<int64_t, Src::rank()>>(1);
return emitc::slice<Dest, Src>(x, start_indices, limit_indices, strides);
}

// PadOp
template <typename Dest, typename Src, typename Padding>
inline Dest pad(Src operand, Padding padding) {
Expand Down
37 changes: 37 additions & 0 deletions lib/Dialect/EmitC/Conversion/TosaToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,38 @@ class PadOpConversion : public OpConversionPattern<tosa::PadOp> {
}
};

/// Convert `tosa.slice` into an `emitc.call` operation.
class SliceOpConversion : public OpConversionPattern<tosa::SliceOp> {
using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

public:
SliceOpConversion(MLIRContext *ctx)
: OpConversionPattern<tosa::SliceOp>(ctx) {}

private:
LogicalResult
matchAndRewrite(tosa::SliceOp sliceOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
StringAttr callee = rewriter.getStringAttr("tosa::slice");

// clang-format off
ArrayAttr args = rewriter.getArrayAttr({
rewriter.getIndexAttr(0),
getI64ElementsAttr(sliceOp.startAttr(), sliceOp.getContext()),
getI64ElementsAttr(sliceOp.sizeAttr(), sliceOp.getContext()),
});
// clang-format on

Type resultType = sliceOp.output().getType();
ArrayAttr templateArgs = rewriter.getArrayAttr({TypeAttr::get(resultType)});

rewriter.replaceOpWithNewOp<emitc::CallOp>(
sliceOp, sliceOp.getType(), callee, args, templateArgs, operands);

return success();
}
};

} // namespace

void populateTosaToEmitcPatterns(MLIRContext *ctx,
Expand All @@ -656,6 +688,8 @@ void populateTosaToEmitcPatterns(MLIRContext *ctx,

// Insert patterns for TOSA unary elementwise ops.
patterns.insert<CallOpConversion<tosa::AbsOp>>(ctx, "tosa::abs");
patterns.insert<CallOpConversion<tosa::CastOp>>(ctx, "tosa::cast",
/*explicitResultType=*/true);
patterns.insert<CallOpConversion<tosa::CeilOp>>(ctx, "tosa::ceil");
patterns.insert<ClampOpConversion>(ctx);
patterns.insert<CallOpConversion<tosa::ExpOp>>(ctx, "tosa::exp");
Expand Down Expand Up @@ -698,6 +732,7 @@ void populateTosaToEmitcPatterns(MLIRContext *ctx,
"tosa::reduce_sum");
patterns.insert<CallOpConversion<tosa::ReshapeOp>>(
ctx, "tosa::reshape", /*explicitResultType=*/true);
patterns.insert<SliceOpConversion>(ctx);
patterns.insert<PadOpConversion>(ctx);
patterns.insert<CallOpConversion<tosa::TransposeOp>>(
ctx, "tosa::transpose", /*explicitResultType=*/true);
Expand All @@ -724,6 +759,7 @@ struct ConvertTosaToEmitCPass

// Unary elementwise ops
target.addIllegalOp<tosa::AbsOp>();
target.addIllegalOp<tosa::CastOp>();
target.addIllegalOp<tosa::CeilOp>();
target.addIllegalOp<tosa::ClampOp>();
target.addIllegalOp<tosa::ExpOp>();
Expand Down Expand Up @@ -755,6 +791,7 @@ struct ConvertTosaToEmitCPass
target.addIllegalOp<tosa::ReduceProdOp>();
target.addIllegalOp<tosa::ReduceSumOp>();
target.addIllegalOp<tosa::ReshapeOp>();
target.addIllegalOp<tosa::SliceOp>();
target.addIllegalOp<tosa::PadOp>();
target.addIllegalOp<tosa::TransposeOp>();

Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/tosa-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
return %0 : tensor<13x21x3xf32>
}

// CHECK-LABEL: cast
func @test_cast(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
// CHECK: %0 = emitc.call "tosa::cast"(%arg0) {template_args = [tensor<13x21x3xf32>]} : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
%0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: emitc.call "tosa::ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
%0 = "tosa.ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
Expand Down Expand Up @@ -235,6 +242,12 @@ func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x3xf32> {
return %0 : tensor<13x1x3xf32>
}

func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> {
// CHECK: %0 = emitc.call "tosa::slice"(%arg0) {args = [0 : index, dense<[6, 8, 0]> : tensor<3xi64>, dense<[4, 11, 1]> : tensor<3xi64>], template_args = [tensor<4x11x1xf32>]} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32>
%0 = "tosa.slice"(%arg0) {start = [6, 8, 0], size = [4, 11, 1]} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32>
return %0 : tensor<4x11x1xf32>
}

func @test_pad(%arg0: tensor<2x3xf32>, %arg1: tensor<2x2xi32>) -> tensor<3x6xf32> {
// CHECK: %0 = emitc.call "tosa::pad"(%arg0, %arg1) {template_args = [tensor<3x6xf32>]} : (tensor<2x3xf32>, tensor<2x2xi32>) -> tensor<3x6xf32>
%0 = "tosa.pad"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2x2xi32>) -> tensor<3x6xf32>
Expand Down
37 changes: 37 additions & 0 deletions unittests/emitc_tosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,43 @@ TEST(tosa, reshape) {
EXPECT_THAT(s3, Pointwise(Eq(), t3));
}

TEST(tosa, slice) {
// Slice Tensor1D
Tensor1D<float, 5> s1{0.0f, 1.0f, 2.0f, 3.0f, 4.0f};
auto t1 = tosa::slice<Tensor1D<float, 2>>(s1, {2}, {2});
EXPECT_THAT(t1, Pointwise(FloatEq(), {2.0f, 3.0f}));

// Slice Tensor2D
Tensor2D<float, 4, 3> s2{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f};
auto t2 = tosa::slice<Tensor2D<float, 2, 2>>(s2, {2, 1}, {2, 2});

EXPECT_THAT(t2, Pointwise(FloatEq(), {7.0f, 8.0f, 10.0f, 11.0f}));

// Slice Tensor3D
Tensor3D<float, 4, 3, 2> s3{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f,
12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f};
auto t3 = tosa::slice<Tensor3D<float, 2, 2, 2>>(s3, {2, 1, 0}, {2, 2, 2});
EXPECT_THAT(t3, Pointwise(FloatEq(), {14.0f, 15.0f, 16.0f, 17.0f, 20.0f,
21.0f, 22.0f, 23.0f}));

// Slice Tensor4D
Tensor4D<float, 4, 3, 1, 2> s4{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f,
12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f,
18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f};
auto t4 =
tosa::slice<Tensor4D<float, 2, 2, 1, 2>>(s4, {2, 1, 0, 0}, {2, 2, 1, 2});
EXPECT_THAT(t4, Pointwise(FloatEq(), {14.0f, 15.0f, 16.0f, 17.0f, 20.0f,
21.0f, 22.0f, 23.0f}));

auto t4_2 =
tosa::slice<Tensor4D<float, 4, 3, 1, 2>>(s4, {0, 0, 0, 0}, {4, 3, 1, 2});
EXPECT_THAT(t4_2, Pointwise(FloatEq(), s4));
}

TEST(tosa, pad) {
// clang-format off
Tensor<int32_t, 2, 3> operand0{1, 2, 3,
Expand Down

0 comments on commit 4de3eec

Please sign in to comment.