Skip to content

Commit

Permalink
Add Dynamic Range Quantized op support for op_stat_pass.cc.
Browse files Browse the repository at this point in the history
- Cleanup header imports as well.

PiperOrigin-RevId: 614784461
  • Loading branch information
chococigar authored and tensorflower-gardener committed Mar 11, 2024
1 parent 0318011 commit c3fd2b6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,14 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/quantization/common/quantization_lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
Expand Down
37 changes: 28 additions & 9 deletions tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,27 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/log/log.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h"
Expand Down Expand Up @@ -89,29 +98,39 @@ void PrintOpStatsPass::runOnOperation() {
auto op_name = op->getName().stripDialect();
auto dialect_name = op->getDialect()->getNamespace();

if (op->getNumResults() > 0) {
if (auto shaped_type =
op->getResult(0).getType().dyn_cast_or_null<ShapedType>()) {
auto result = shaped_type.getElementType();
if (op->getNumResults() > 0 &&
isa<ShapedType>(op->getResult(0).getType())) {
// Use rhs operand to detect types for dynamic range quantizable ops.
Value value_for_deducing_op_type =
(dyn_cast_or_null<DynamicRangeQuantizedOpInterface>(op))
? op->getOperand(1)
: op->getResult(0);
ShapedType value_shaped_type =
value_for_deducing_op_type.getType().dyn_cast_or_null<ShapedType>();
if (value_shaped_type != nullptr) {
auto operand_or_result = value_shaped_type.getElementType();
std::string dtype;

TypeSwitch<Type>(result)
TypeSwitch<Type>(operand_or_result)
.Case<IntegerType>([&](Type) {
dtype = absl::StrCat("i", result.getIntOrFloatBitWidth());
dtype =
absl::StrCat("i", operand_or_result.getIntOrFloatBitWidth());
})
.Case<FloatType>([&](Type) {
dtype = absl::StrCat("f", result.getIntOrFloatBitWidth());
dtype =
absl::StrCat("f", operand_or_result.getIntOrFloatBitWidth());
})
.Case<UniformQuantizedType>([&](Type) {
auto uniform_quantized_dtype =
result.dyn_cast_or_null<UniformQuantizedType>()
operand_or_result.dyn_cast_or_null<UniformQuantizedType>()
.getStorageType();
dtype = absl::StrCat(
"uq_", uniform_quantized_dtype.getIntOrFloatBitWidth());
})
.Case<quant::UniformQuantizedPerAxisType>([&](Type) {
auto uniform_quantized_dtype =
result.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()
operand_or_result
.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()
.getStorageType();
dtype = absl::StrCat(
"uq_", uniform_quantized_dtype.getIntOrFloatBitWidth());
Expand Down

0 comments on commit c3fd2b6

Please sign in to comment.