diff --git a/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir b/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir
index 31b6e7968cdede..7ddd4baaee5827 100644
--- a/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir
@@ -87,3 +87,27 @@ func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg1: te
 // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform>)
 // CHECK: return %[[dq]]
 }
+
+// CHECK-LABEL: test_region
+func @test_region(%arg0: tensor<128x128x!quant.uniform>,
+    %arg1: tensor<1x!quant.uniform>, %arg2: tensor<1x!quant.uniform>,
+    %arg3: tensor<1xi32>) -> (tensor<128x128x!quant.uniform>) {
+  %0 = "tfl.dequantize"(%arg0) : (tensor<128x128x!quant.uniform>) -> tensor<128x128xf32>
+  %1 = "tfl.dequantize"(%arg1) : (tensor<1x!quant.uniform>) -> tensor<1xf32>
+  %2 = "tfl.dequantize"(%arg2) : (tensor<1x!quant.uniform>) -> tensor<1xf32>
+  %3 = "tfl.custom_tf"(%0, %1, %2, %arg3) ( {
+  ^bb0(%a1: tensor<128x128xf32>, %a2: tensor<1xf32>, %a3: tensor<1xf32>, %a4: tensor<1xi32>): // no predecessors
+    %4 = "tf.LayerNorm"(%a1, %a2, %a3, %a4) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
+    "tfl.yield"(%4) : (tensor<128x128xf32>) -> ()
+  }) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
+  %4 = "tfl.quantize"(%3) {qtype = tensor<128x128x!quant.uniform>} : (tensor<128x128xf32>) -> tensor<128x128x!quant.uniform>
+  return %4 : tensor<128x128x!quant.uniform>
+
+// CHECK: "tfl.custom_tf"
+// CHECK-NEXT: ^bb0(%arg4: tensor<128x128xf32>, %arg5: tensor<1xf32>, %arg6: tensor<1xf32>, %arg7: tensor<1xi32>): // no predecessors
+// CHECK-NEXT: "tf.LayerNorm"(%arg4, %arg5, %arg6, %arg7) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
+// CHECK-NEXT: "tfl.yield"
+// CHECK-NEXT: }) {_tfl_quant_trait = "fully_quantizable", device = ""} :
+// CHECK-SAME: (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
index c4474d4c6e37ed..b66d46b6a3fea1 100644
--- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc
@@ -109,10 +109,10 @@ void DefaultQuantParamsPass::runOnFunction() {
   }
 
   func.walk([&](Operation *op) {
-    if (op->hasTrait() ||
-        op->hasTrait() ||
-        llvm::isa(op))
+    if (quant::IsOpNotQuantizable(op) ||
+        op->getParentOfType()) {
       return;
+    }
 
     for (auto res : op->getResults()) {
       if (UsedAsBias(res)) {