Skip to content

Commit

Permalink
Skip the default quantization parameters inside the custom op
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 375010149
Change-Id: I9496b18097e93a16a5291c4bd9cd35416f90b95e
  • Loading branch information
liufengdb authored and tensorflower-gardener committed May 21, 2021
1 parent 1877c2f commit ed56215
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
24 changes: 24 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/default_quant_params.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,27 @@ func @test_conv_2d_activation_and_bias(%arg0: tensor<1x224x224x3xf32>, %arg1: te
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[conv]]) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.0078431372549019607:128>>)
// CHECK: return %[[dq]]
}

// CHECK-LABEL: test_region
// Regression test for the default-quant-params pass: an op nested inside a
// tfl.custom_tf region (here tf.LayerNorm) must be left in float — the pass
// should NOT assign default quantization parameters to ops whose parent is a
// TFL::CustomTfOp (see the getParentOfType<TFL::CustomTfOp> guard added in
// DefaultQuantParamsPass::runOnFunction in this same change).
func @test_region(%arg0: tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>,
%arg1: tensor<1x!quant.uniform<u8:f32, 0.2:127>>, %arg2: tensor<1x!quant.uniform<u8:f32, 0.4:127>>,
%arg3: tensor<1xi32>) -> (tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>) {
// Dequantize the three quantized inputs so the custom op consumes plain f32.
%0 = "tfl.dequantize"(%arg0) : (tensor<128x128x!quant.uniform<u8:f32, 0.1:127>>) -> tensor<128x128xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<1x!quant.uniform<u8:f32, 0.2:127>>) -> tensor<1xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<1x!quant.uniform<u8:f32, 0.4:127>>) -> tensor<1xf32>
// Custom op wrapping a TF op in a region; everything inside must stay f32.
%3 = "tfl.custom_tf"(%0, %1, %2, %arg3) ( {
^bb0(%a1: tensor<128x128xf32>, %a2: tensor<1xf32>, %a3: tensor<1xf32>, %a4: tensor<1xi32>): // no predecessors
%4 = "tf.LayerNorm"(%a1, %a2, %a3, %a4) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
"tfl.yield"(%4) : (tensor<128x128xf32>) -> ()
}) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
// Explicit requantization of the custom op's float result.
%4 = "tfl.quantize"(%3) {qtype = tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>} : (tensor<128x128xf32>) -> tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>
return %4 : tensor<128x128x!quant.uniform<u8:f32, 0.2:125>>

// The expected output pins the region's block arguments and the inner
// tf.LayerNorm to f32 types, proving no default quant params were injected.
// CHECK: "tfl.custom_tf"
// CHECK-NEXT: ^bb0(%arg4: tensor<128x128xf32>, %arg5: tensor<1xf32>, %arg6: tensor<1xf32>, %arg7: tensor<1xi32>): // no predecessors
// CHECK-NEXT: "tf.LayerNorm"(%arg4, %arg5, %arg6, %arg7) {_tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
// CHECK-NEXT: "tfl.yield"
// CHECK-NEXT: }) {_tfl_quant_trait = "fully_quantizable", device = ""} :
// CHECK-SAME: (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
}

Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ void DefaultQuantParamsPass::runOnFunction() {
}

func.walk([&](Operation *op) {
if (op->hasTrait<OpTrait::IsTerminator>() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(op))
if (quant::IsOpNotQuantizable(op) ||
op->getParentOfType<TFL::CustomTfOp>()) {
return;
}

for (auto res : op->getResults()) {
if (UsedAsBias(res)) {
Expand Down

0 comments on commit ed56215

Please sign in to comment.