Skip to content

Commit

Permalink
[Frontend][TFLite] Add parser support for l2_normalization (apache#4966)
Browse files Browse the repository at this point in the history
* [Frontend][TFLite] Add parser support for l2_normalization

* TF doesn't provide uint8 support
* TFL does the normalization only if it's over the last axis
* TFL uses only the default value for expilon

* Change error message
  • Loading branch information
inadob authored Feb 29, 2020
1 parent a449d8b commit 2355caa
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
47 changes: 47 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(self, model, subgraph, exp_tab):
'LOGICAL_OR': self.convert_logical_or,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'SQUARE': self.convert_square,
'L2_NORMALIZATION': self.convert_l2_normalization,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -405,6 +406,52 @@ def convert_resize_nearest_neighbor(self, op):
"""Convert TFLite RESIZE_NEAREST_NEIGHBOR"""
return self._convert_resize("nearest_neighbor", op)

def convert_l2_normalization(self, op):
"""Convert TFLite L2_NORMALIZATION """
try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.L2NormOptions import L2NormOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

assert op.BuiltinOptionsType() == BuiltinOptions.L2NormOptions
op_options = op.BuiltinOptions()
l2_norm_options = L2NormOptions()
l2_norm_options.Init(op_options.Bytes, op_options.Pos)
fused_activation_fn = l2_norm_options.FusedActivationFunction()

# TFLite supports normalization only over the last dim
input_tensor_rank = len(input_tensor.tensor.ShapeAsNumpy())

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
# TFL uses only the default epsilon value
out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1])

# if we have fused activation fn
if fused_activation_fn != ActivationFunctionType.NONE:
if not output_tensor.qnn_params:
out = self.convert_fused_activation_function(out, fused_activation_fn)
else:
raise tvm.error.OpNotImplemented(
'TFLite quantized L2_NORMALIZATION operator\
with fused activation function is not supported yet.')

return out

def convert_logistic(self, op):
"""Convert TFLite LOGISTIC"""
try:
Expand Down
20 changes: 20 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
try:
from tensorflow import lite as interpreter_wrapper
Expand Down Expand Up @@ -1263,6 +1264,24 @@ def test_forward_unpack():
_test_unpack(np.array(np.random.uniform(0, 5, (3, 6)), dtype=np.int32), axis=-2, num_unpacks=3)
_test_unpack(np.array(np.random.uniform(0, 5, (2, 3, 4)), dtype=np.int32), axis=-3, num_unpacks=2)

#######################################################################
# L2 normalization
# ----------------

def _test_l2_normalization(data, axis, fused_activation_function=None):
""" One iteration of L2_NORMALIZATION """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = nn_impl.l2_normalize(in_data, axis)
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])

def test_forward_l2_normalization():
""" L2_NORMALIZATION """
data = np.random.uniform(size=(3, 6, 4)).astype('float32')
_test_l2_normalization(data, axis=2)
_test_l2_normalization(data, axis=2, fused_activation_function="RELU")

#######################################################################
# Logistic
# --------
Expand Down Expand Up @@ -1649,6 +1668,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_relu()
test_forward_prelu()
test_forward_fully_connected()
test_forward_l2_normalization()

# Elemwise
test_all_elemwise()
Expand Down

0 comments on commit 2355caa

Please sign in to comment.