Skip to content

Commit

Permalink
Move constant normalization rules into Decompose pass (onnx#2704)
Browse files Browse the repository at this point in the history
* Move constant normalization rules into Decompose pass

Signed-off-by: Tung D. Le <[email protected]>

* clang format

Signed-off-by: Tung D. Le <[email protected]>

* Edit messages

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld authored Feb 6, 2024
1 parent 8436fba commit f4fddbe
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 116 deletions.
35 changes: 1 addition & 34 deletions src/Dialect/ONNX/Rewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,32 +65,6 @@ DenseElementsAttr createDenseElementsAttrOfNToM(
return rewriter.getI64TensorAttr(vals);
}

Value normalizeConstantOp(
PatternRewriter &rewriter, Value output, Attribute attr) {
ShapedType outputType = output.getType().cast<ShapedType>();
Type elementType = outputType.getElementType();

DenseElementsAttr denseAttr;
if (ArrayAttr arrayAttr = attr.dyn_cast<ArrayAttr>()) {
int64_t dim = arrayAttr.size();
auto tensorType = RankedTensorType::get({dim}, elementType);
denseAttr = DenseElementsAttr::get(tensorType, arrayAttr.getValue());
} else {
auto tensorType = RankedTensorType::get({}, elementType);
if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
denseAttr = DenseElementsAttr::get(tensorType, {floatAttr.getValue()});
} else if (IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>()) {
denseAttr = DenseElementsAttr::get(tensorType, intAttr.getSInt());
} else if (StringAttr strAttr = attr.dyn_cast<StringAttr>()) {
denseAttr = DenseElementsAttr::get(tensorType, {strAttr.getValue()});
} else {
llvm_unreachable("unexpected Attribute");
}
}
OnnxBuilder createONNX(rewriter, output.getLoc());
return createONNX.constant(denseAttr);
}

// Get return type for a MatMulOp whose A's rank is N (>2) and B's rank is 2.
Type getReturnTypeForMatMulOpND2D(Value A, Value B) {
ArrayRef<int64_t> aShape = A.getType().cast<RankedTensorType>().getShape();
Expand Down Expand Up @@ -1548,14 +1522,7 @@ void ONNXCastOp::getCanonicalizationPatterns(

/// on the ONNXConstantOp.
void ONNXConstantOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.insert<ConstantOpNormalizationPattern1>(context);
results.insert<ConstantOpNormalizationPattern2>(context);
results.insert<ConstantOpNormalizationPattern3>(context);
results.insert<ConstantOpNormalizationPattern4>(context);
results.insert<ConstantOpNormalizationPattern5>(context);
results.insert<ConstantOpNormalizationPattern6>(context);
}
RewritePatternSet &results, MLIRContext *context) {}

/// on the ONNXDepthToSpaceOp.
void ONNXDepthToSpaceOp::getCanonicalizationPatterns(
Expand Down
51 changes: 1 addition & 50 deletions src/Dialect/ONNX/Rewrite.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,8 @@ def createDenseElementsAttrOfOneToRankOfExclusive : NativeCodeCall<
def createArrayAttrOfTwoToRankOf : NativeCodeCall<
"onnx_mlir::createArrayAttrOfNToM($_builder, 2, $0.getType().cast<ShapedType>().getRank() - 1)">;

def ONNXConstantOpNormalize: NativeCodeCall<
"onnx_mlir::normalizeConstantOp($_builder, $0, $1)">;

def AttributeIsNotNull :
Constraint<CPred<" ($_self) ">, "Attribute is null">;
Constraint<CPred<"($_self)">, "Attribute is not null">;

def IsDenseElementsAttr :
Constraint<And<[CPred<" ($_self) ">,
Expand Down Expand Up @@ -911,52 +908,6 @@ def RemoveSpaceToDepthDepthToSpacePattern : Pat<
[(Equal $bs1, $bs2), (EqualString<"CRD"> $mode)]
>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXConstantOp
//===----------------------------------------------------------------------===//

def ConstantOpNormalizationPattern1: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $floatAttr),
[(AttributeIsNotNull:$floatAttr)]
>;

def ConstantOpNormalizationPattern2: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $intAttr),
[(AttributeIsNotNull:$intAttr)]
>;

def ConstantOpNormalizationPattern3: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $stringAttr),
[(AttributeIsNotNull:$stringAttr)]
>;

def ConstantOpNormalizationPattern4: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $floatsAttr),
[(AttributeIsNotNull:$floatsAttr)]
>;

def ConstantOpNormalizationPattern5: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $intsAttr),
[(AttributeIsNotNull:$intsAttr)]
>;

def ConstantOpNormalizationPattern6: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $stringsAttr),
[(AttributeIsNotNull:$stringsAttr)]
>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXLessOp
//===----------------------------------------------------------------------===//
Expand Down
33 changes: 33 additions & 0 deletions src/Transform/ONNX/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,32 @@ Value insertAdditionalPadsConvTranspose(PatternRewriter &rewriter, Location loc,
}
// ConvTransposeOp END

Value normalizeConstantOp(
PatternRewriter &rewriter, Value output, Attribute attr) {
ShapedType outputType = output.getType().cast<ShapedType>();
Type elementType = outputType.getElementType();

DenseElementsAttr denseAttr;
if (ArrayAttr arrayAttr = attr.dyn_cast<ArrayAttr>()) {
int64_t dim = arrayAttr.size();
auto tensorType = RankedTensorType::get({dim}, elementType);
denseAttr = DenseElementsAttr::get(tensorType, arrayAttr.getValue());
} else {
auto tensorType = RankedTensorType::get({}, elementType);
if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
denseAttr = DenseElementsAttr::get(tensorType, {floatAttr.getValue()});
} else if (IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>()) {
denseAttr = DenseElementsAttr::get(tensorType, intAttr.getSInt());
} else if (StringAttr strAttr = attr.dyn_cast<StringAttr>()) {
denseAttr = DenseElementsAttr::get(tensorType, {strAttr.getValue()});
} else {
llvm_unreachable("unexpected Attribute");
}
}
onnx_mlir::OnnxBuilder createONNX(rewriter, output.getLoc());
return createONNX.constant(denseAttr);
}

} // namespace onnx_mlir

namespace {
Expand Down Expand Up @@ -964,6 +990,13 @@ void DecomposeONNXToONNXPass::runOnOperation() {
return !isConcatFuseMatched(op, shapeOp, transposeOp);
});

// Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs.
target.addDynamicallyLegalOp<ONNXConstantOp>([](ONNXConstantOp op) {
return !(op.getValueFloatAttr() || op.getValueFloatsAttr() ||
op.getValueIntAttr() || op.getValueIntsAttr() ||
op.getValueStringAttr() || op.getValueStringsAttr());
});

// Decompose CustomOp FusedMatMul introduced by onnxruntime:
// https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul
target.addDynamicallyLegalOp<ONNXCustomOp>([](ONNXCustomOp op) {
Expand Down
51 changes: 51 additions & 0 deletions src/Transform/ONNX/Decompose.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def KeepdimsIsTrue
: Constraint<CPred<"$_self.cast<IntegerAttr>().getSInt() == 1">,
"keepdims attribute is true">;

def ONNXConstantOpNormalize: NativeCodeCall<
"onnx_mlir::normalizeConstantOp($_builder, $0, $1)">;

def AttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">;

def AttributeIsNotNull : Constraint<CPred<"($_self)">, "Attribute is not null">;

def HasFloatType : Constraint<CPred<"(($_self).getType().dyn_cast<ShapedType>()"
".getElementType().isF32())">>;

Expand Down Expand Up @@ -549,4 +554,50 @@ def ConstantOfShapePattern: Pat<
$shape)
>;

//===----------------------------------------------------------------------===//
// Normalize ONNXConstantOp to use ElementAttrs
//===----------------------------------------------------------------------===//

def ConstantOpNormalizationPattern1: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $floatAttr),
[(AttributeIsNotNull:$floatAttr)]
>;

def ConstantOpNormalizationPattern2: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $intAttr),
[(AttributeIsNotNull:$intAttr)]
>;

def ConstantOpNormalizationPattern3: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $stringAttr),
[(AttributeIsNotNull:$stringAttr)]
>;

def ConstantOpNormalizationPattern4: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $floatsAttr),
[(AttributeIsNotNull:$floatsAttr)]
>;

def ConstantOpNormalizationPattern5: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $intsAttr),
[(AttributeIsNotNull:$intsAttr)]
>;

def ConstantOpNormalizationPattern6: Pat<
(ONNXConstantOp:$res $sparseAttr, $denseAttr, $floatAttr, $floatsAttr, $intAttr,
$intsAttr, $stringAttr, $stringsAttr),
(ONNXConstantOpNormalize $res, $stringsAttr),
[(AttributeIsNotNull:$stringsAttr)]
>;

#endif // ONNX_DECOMPOSE
32 changes: 0 additions & 32 deletions test/mlir/onnx/onnx_canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -680,38 +680,6 @@ func.func @test_remove_space_to_depth_depth_to_space(%arg0 : tensor<1x256x8x16xf

// -----

func.func @test_constant_1() -> tensor<i64> {
%0 = onnx.Constant {value_int = 1 : si64} : tensor<i64>
onnx.Return %0 : tensor<i64>
// CHECK-LABEL: func @test_constant_1
// CHECK: [[VAR_0:%.+]] = onnx.Constant dense<1> : tensor<i64>
// CHECK: onnx.Return [[VAR_0]] : tensor<i64>
}


// -----

func.func @test_constant_2() -> tensor<f32> {
%0 = onnx.Constant {value_float = 2.0 : f32 } : tensor<f32>
onnx.Return %0 : tensor<f32>
// CHECK-LABEL: func @test_constant_2
// CHECK: [[VAR_0:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<f32>
// CHECK: onnx.Return [[VAR_0]] : tensor<f32>
}

// -----

func.func @test_constant_3() -> tensor<3xi64> {
%0 = onnx.Constant {value_ints = [1, 2, 3] } : tensor<3xi64>
onnx.Return %0 : tensor<3xi64>
// CHECK-LABEL: func @test_constant_3
// CHECK-SAME: () -> tensor<3xi64> {
// CHECK: [[VAR_0:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64>
// CHECK: onnx.Return [[VAR_0]] : tensor<3xi64>
}

// -----

func.func @test_rewrite_batchnormtestmode_Nd(%arg0 : tensor<1x64x112x112xf32>, %scale : tensor<64xf32>, %bias : tensor<64xf32>, %mean : tensor<64xf32>, %var : tensor<64xf32>) -> tensor<1x64x112x112xf32> {
%0 = "onnx.BatchNormalizationInferenceMode"(%arg0, %scale, %bias, %mean, %var) {epsilon = 1.00000007E-5 : f32} : (tensor<1x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32>
onnx.Return %0 : tensor<1x64x112x112xf32>
Expand Down
34 changes: 34 additions & 0 deletions test/mlir/onnx/onnx_decompose.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,37 @@ func.func @test_instancenorm(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor<3xf32>,
// CHECK: onnx.Return [[Y_]] : tensor<2x3x4x5x6xf32>
// CHECK: }
}

// -----

func.func @test_constant_1() -> tensor<i64> {
%0 = onnx.Constant {value_int = 1 : si64} : tensor<i64>
onnx.Return %0 : tensor<i64>
// CHECK-LABEL: func @test_constant_1
// CHECK: [[VAR_0:%.+]] = onnx.Constant dense<1> : tensor<i64>
// CHECK: onnx.Return [[VAR_0]] : tensor<i64>
}


// -----

func.func @test_constant_2() -> tensor<f32> {
%0 = onnx.Constant {value_float = 2.0 : f32 } : tensor<f32>
onnx.Return %0 : tensor<f32>
// CHECK-LABEL: func @test_constant_2
// CHECK: [[VAR_0:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<f32>
// CHECK: onnx.Return [[VAR_0]] : tensor<f32>
}

// -----

func.func @test_constant_3() -> tensor<3xi64> {
%0 = onnx.Constant {value_ints = [1, 2, 3] } : tensor<3xi64>
onnx.Return %0 : tensor<3xi64>
// CHECK-LABEL: func @test_constant_3
// CHECK-SAME: () -> tensor<3xi64> {
// CHECK: [[VAR_0:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64>
// CHECK: onnx.Return [[VAR_0]] : tensor<3xi64>
}

// -----

0 comments on commit f4fddbe

Please sign in to comment.