[mlir][tosa] Add verifiers to ReduceOps, fix shape inference crash (llvm#69843)

This patch adds verifiers to the `tosa.reduce_*` ops that check, among other
things, that the supplied `axis` argument is compatible with the shapes of the
input and output tensors. The special case `axis == 0 && rank == 0` remains
valid. The patch also adds a check to `ReduceInferReturnTypes()` so that the
shape inference pass no longer crashes on an invalid `axis` argument.

Fix llvm#68187
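
For illustration, a minimal example of the kind of IR the new verifiers reject (the function name is hypothetical; the op, attribute, and diagnostic text mirror the tests added to invalid.mlir below). Before this patch, an out-of-range `axis` like this could reach shape inference and crash it:

func.func @reduce_axis_out_of_range(%arg0 : tensor<2x3x4xf32>) -> () {
  // axis = 3 on a rank-3 input is now diagnosed by the verifier:
  // 'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)
  %0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
  return
}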
ubfx authored Oct 24, 2023
1 parent d593f6c commit 8a57bc0
Showing 4 changed files with 125 additions and 6 deletions.
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1271,6 +1271,7 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
);

let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1304,6 +1305,7 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
);

let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1337,6 +1339,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
);

let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1371,6 +1374,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
);

let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1405,6 +1409,7 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
);

let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1436,8 +1441,10 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
let results = (outs
Tosa_Tensor:$output
);
let hasFolder = 1;

let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
61 changes: 59 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1109,14 +1109,14 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
if (!operandShape.hasRank() || operandShape.getRank() == 0) {
int64_t axisVal = axis.getValue().getSExtValue();
if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
}

SmallVector<int64_t> outputShape;
operandShape.getDims(outputShape);
int64_t axisVal = axis.getValue().getSExtValue();
outputShape[axisVal] = 1;
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
@@ -1155,6 +1155,63 @@ REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES

template <typename T>
static LogicalResult verifyReduceOp(T op) {
// All TOSA reduce Ops have input, output and axis.
TensorType inputType = op.getInput().getType();
TensorType outputType = op.getOutput().getType();
int32_t reduceAxis = op.getAxis();

if (reduceAxis < 0) {
op.emitOpError("reduce axis must not be negative");
return failure();
}
if (inputType.hasRank()) {
int64_t inputRank = inputType.getRank();
// We allow for a special case where the input/output shape has rank 0 and
// axis is also 0.
if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
op.emitOpError("expect input tensor rank (")
<< inputRank << ") to be larger than reduce axis (" << reduceAxis
<< ")";
return failure();
}
}
if (outputType.hasRank()) {
int64_t outputRank = outputType.getRank();
if (inputType.hasRank() && outputRank != inputType.getRank()) {
op.emitOpError(
"expect output tensor rank to be equal to input tensor rank");
return failure();
}
if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
op.emitOpError("expect output tensor rank (")
<< outputRank << ") to be larger than reduce axis (" << reduceAxis
<< ")";
return failure();
}
// We can only verify the reduced dimension size to be 1 if this is not the
// special case of output rank == 0.
if (outputRank != 0) {
auto outputShape = outputType.getShape();
if (!outputType.isDynamicDim(reduceAxis) &&
outputShape[reduceAxis] != 1) {
op.emitOpError("expect reduced dimension size to be 1, got ")
<< outputShape[reduceAxis];
return failure();
}
}
}
return success();
}

LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }

static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -599,7 +599,7 @@ func.func nested @fold_reduce_rank_zero() {
// CHECK-NOT: tosa.reduce_min
// CHECK-NOT: tosa.reverse
%0 = tensor.empty() : tensor<i32>
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
%1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
%2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
return
}
59 changes: 57 additions & 2 deletions mlir/test/Dialect/Tosa/invalid.mlir
@@ -128,14 +128,69 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----

func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
// expected-error@+1 {{'tosa.reduce_prod' op expect reduced dimension size to be 1, got 3}}
%0 = tosa.reduce_prod %arg0 {axis = 1 : i32} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
return
}

// -----

func.func @test_reduce_all_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// expected-error@+1 {{'tosa.reduce_all' op expect input tensor rank (3) to be larger than reduce axis (3)}}
%0 = tosa.reduce_all %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
return
}

// -----

func.func @test_reduce_any_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// expected-error@+1 {{'tosa.reduce_any' op expect input tensor rank (3) to be larger than reduce axis (3)}}
%0 = tosa.reduce_any %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
return
}

// -----

func.func @test_reduce_max_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// expected-error@+1 {{'tosa.reduce_max' op expect input tensor rank (3) to be larger than reduce axis (3)}}
%0 = tosa.reduce_max %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
return
}

// -----

func.func @test_reduce_min_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// expected-error@+1 {{'tosa.reduce_min' op expect input tensor rank (3) to be larger than reduce axis (3)}}
%0 = tosa.reduce_min %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
return
}

// -----

func.func @test_reduce_prod_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// expected-error@+1 {{'tosa.reduce_prod' op expect input tensor rank (3) to be larger than reduce axis (3)}}
%0 = tosa.reduce_prod %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
return
}

// -----

func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// expected-error@+1 {{'tosa.reduce_sum' op expect input tensor rank (3) to be larger than reduce axis (3)}}
%0 = tosa.reduce_sum %arg0 {axis = 3 : i32} : (tensor<2x3x4xf32>) -> tensor<2x3x1xf32>
return
}

// -----

func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
// expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
%0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
return
}

// -----

func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
