Skip to content

Commit

Permalink
Check in seq_flow_lite (tensorflow#10502)
Browse files Browse the repository at this point in the history
  • Loading branch information
pyoung2778 authored Feb 15, 2022
1 parent 3d0e12f commit 79b6c2f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
6 changes: 1 addition & 5 deletions research/seq_flow_lite/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ build:manylinux2010 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.

build -c opt
build --cxxopt="-std=c++14"
build --host_cxxopt="-std=c++14"
build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"
build --auto_output_filter=subpackages
build --copt="-Wall" --copt="-Wno-sign-compare"
Expand Down Expand Up @@ -60,11 +61,6 @@ build --define=grpc_no_ares=true
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive

# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1

# Build TF with C++ 17 features.
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
Expand Down
34 changes: 34 additions & 0 deletions research/seq_flow_lite/utils/tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,40 @@ def _dump_graph_in_text_format(filename, graph_def):
f.close()


def get_mean_stddev_values(min_value_of_features, max_value_of_features):
"""Gets Mean and Stddev values for given min/max float values."""
quant_min = 0
quant_max = 255

min_global = min_value_of_features
max_global = max_value_of_features

quant_min_float = float(quant_min)
quant_max_float = float(quant_max)

nudged_scale = (max_global - min_global) / (quant_max_float - quant_min_float)

zero_point_from_min = quant_min_float - min_global / nudged_scale

if zero_point_from_min < quant_min_float:
nudged_zero_point = int(quant_min)
elif zero_point_from_min > quant_max_float:
nudged_zero_point = int(quant_max)
else:
nudged_zero_point = int(round(zero_point_from_min))

nudged_min = (quant_min_float - nudged_zero_point) * (nudged_scale)
nudged_max = (quant_max_float - nudged_zero_point) * (nudged_scale)

zero_point = (quant_min - min_global) / (max_global - min_global) * quant_max
scale = (nudged_max - nudged_min) / 255.0

mean_value = zero_point
stddev_value = 1 / scale

return mean_value, stddev_value


class InterpreterWithCustomOps(tf.lite.Interpreter):
"""Extended tf.lite.Interpreter."""

Expand Down

0 comments on commit 79b6c2f

Please sign in to comment.