diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index b8c261533fe477..b6471728f947f9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -265,12 +265,14 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc index 91a633745b8834..f7a136f2259ad2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc @@ -22,18 +22,27 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/strings/ascii.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" @@ -89,29 +98,39 @@ void PrintOpStatsPass::runOnOperation() { auto op_name = op->getName().stripDialect(); auto dialect_name = op->getDialect()->getNamespace(); - if (op->getNumResults() > 0) { - if (auto shaped_type = - op->getResult(0).getType().dyn_cast_or_null()) { - auto result = shaped_type.getElementType(); + if (op->getNumResults() > 0 && + isa(op->getResult(0).getType())) { + // Use rhs operand to detect types for dynamic range quantizable ops. + Value value_for_deducing_op_type = + (dyn_cast_or_null(op)) + ? op->getOperand(1) + : op->getResult(0); + ShapedType value_shaped_type = + value_for_deducing_op_type.getType().dyn_cast_or_null(); + if (value_shaped_type != nullptr) { + auto operand_or_result = value_shaped_type.getElementType(); std::string dtype; - TypeSwitch(result) + TypeSwitch(operand_or_result) .Case([&](Type) { - dtype = absl::StrCat("i", result.getIntOrFloatBitWidth()); + dtype = + absl::StrCat("i", operand_or_result.getIntOrFloatBitWidth()); }) .Case([&](Type) { - dtype = absl::StrCat("f", result.getIntOrFloatBitWidth()); + dtype = + absl::StrCat("f", operand_or_result.getIntOrFloatBitWidth()); }) .Case([&](Type) { auto uniform_quantized_dtype = - result.dyn_cast_or_null() + operand_or_result.dyn_cast_or_null() .getStorageType(); dtype = absl::StrCat( "uq_", uniform_quantized_dtype.getIntOrFloatBitWidth()); }) .Case([&](Type) { auto uniform_quantized_dtype = - result.dyn_cast_or_null() + operand_or_result + .dyn_cast_or_null() .getStorageType(); dtype = absl::StrCat( "uq_", uniform_quantized_dtype.getIntOrFloatBitWidth());