Skip to content

Commit bff9884

Browse files
authored
[Keras] Add l2_normalize support (apache#9383)
1 parent 8ada2b1 commit bff9884

File tree

3 files changed

+126
-2
lines changed

3 files changed

+126
-2
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ perf
174174
.bash_history
175175
*.json
176176
*.params
177+
*.ro
177178
*.onnx
178179
*.h5
179180
synset.txt
@@ -240,4 +241,4 @@ conda/pkg
240241
# Downloaded models/datasets
241242
.tvm_test_data
242243
.dgl
243-
.caffe2
244+
.caffe2

python/tvm/relay/frontend/keras.py

+103-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
# pylint: disable=invalid-name, import-self, import-outside-toplevel
1818
"""Keras frontend."""
19+
import dis
1920
import sys
2021
import numpy as np
2122
import tvm
@@ -988,10 +989,110 @@ def _convert_repeat_vector(inexpr, keras_layer, _):
988989
out_shape = [-1, repeats] + input_shape[1:]
989990
out = _op.repeat(inexpr, repeats=repeats, axis=0)
990991
out = _op.reshape(out, out_shape)
991-
992992
return out
993993

994994

995+
def _convert_l2_normalize(inexpr, keras_layer, etab):
996+
l2_normalize_is_loaded = False
997+
param_list = []
998+
for i in dis.get_instructions(keras_layer.function):
999+
if i.opname in ["LOAD_GLOBAL", "LOAD_DEREF"]:
1000+
continue
1001+
if i.opname in ["LOAD_ATTR", "LOAD_METHOD"]:
1002+
if i.argval == "l2_normalize":
1003+
assert not l2_normalize_is_loaded, "l2_normalize was already LOADED"
1004+
l2_normalize_is_loaded = True
1005+
elif i.opname in ["LOAD_CONST", "LOAD_FAST"] and l2_normalize_is_loaded:
1006+
param_list.append(i.argval)
1007+
elif i.opname == "BUILD_LIST":
1008+
sz = i.argval
1009+
assert len(param_list) >= sz
1010+
new_list = param_list[-sz:]
1011+
param_list = param_list[:-sz]
1012+
param_list.append(new_list)
1013+
elif i.opname in ["CALL_FUNCTION_KW", "CALL_METHOD"]:
1014+
break
1015+
1016+
axis = None
1017+
is_param_list_parsed = False
1018+
if l2_normalize_is_loaded and len(param_list) > 0:
1019+
# last param_list item is tuple of strings means that
1020+
# lambda uses named parameters when calling l2_normalize
1021+
if (
1022+
isinstance(param_list[-1], tuple)
1023+
and len(param_list[-1]) > 0
1024+
and isinstance(param_list[-1][0], str)
1025+
):
1026+
param_names = param_list[-1]
1027+
if len(param_names) == 1 and param_names[0] == "x":
1028+
# lambda v: K.l2_normalize(x=v)
1029+
axis = None
1030+
is_param_list_parsed = True
1031+
elif len(param_names) == 1 and param_names[0] == "axis" and len(param_list) == 3:
1032+
# lambda x: K.l2_normalize(x, axis=(2,3))
1033+
axis = param_list[1]
1034+
is_param_list_parsed = True
1035+
elif len(param_names) == 2 and len(param_list) == 3:
1036+
# lambda x: K.l2_normalize(x=x, axis=(2,3))
1037+
# lambda x: K.l2_normalize(axis=(2,3), x=x)
1038+
axis = param_list[param_names.index("axis")]
1039+
is_param_list_parsed = True
1040+
else:
1041+
# lambda x: K.l2_normalize(x)
1042+
if len(param_list) == 1:
1043+
axis = None
1044+
is_param_list_parsed = True
1045+
# lambda x: K.l2_normalize(x, (2,3))
1046+
elif len(param_list) == 2:
1047+
axis = param_list[1]
1048+
is_param_list_parsed = True
1049+
1050+
def is_int_or_tuple_of_ints(v):
1051+
if isinstance(v, list) and len(v) > 0:
1052+
for i in v:
1053+
if not isinstance(i, int):
1054+
return False
1055+
return True
1056+
if isinstance(v, tuple) and len(v) > 0:
1057+
return isinstance(v[0], int)
1058+
return isinstance(v, int)
1059+
1060+
assert is_param_list_parsed and (
1061+
axis is None or is_int_or_tuple_of_ints(axis)
1062+
), "Can not parse l2_normalize lambda function found in Lambda layer"
1063+
if isinstance(axis, int):
1064+
axis = [axis]
1065+
1066+
if etab.data_layout == "NCHW":
1067+
dims = len(keras_layer.input_shape)
1068+
1069+
def fix_axis_for_nchw(axis):
1070+
if axis == 0:
1071+
return 0
1072+
if axis in [(dims - 1), -1]:
1073+
return 1
1074+
return axis + 1
1075+
1076+
axis = [fix_axis_for_nchw(x) for x in axis]
1077+
return _op.nn.l2_normalize(inexpr, eps=1e-12, axis=axis)
1078+
1079+
1080+
def _convert_lambda(inexpr, keras_layer, etab):
1081+
fcode = keras_layer.function.__code__
1082+
# Convert l2_normalize
1083+
if (
1084+
fcode.co_name == "<lambda>"
1085+
and len(fcode.co_names) > 0
1086+
and fcode.co_names[-1] == "l2_normalize"
1087+
):
1088+
return _convert_l2_normalize(inexpr, keras_layer, etab)
1089+
raise tvm.error.OpNotImplemented(
1090+
"Function {} used in Lambda layer is not supported in frontend Keras.".format(
1091+
fcode.co_names
1092+
)
1093+
)
1094+
1095+
9951096
def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
9961097
"""Layers that can be skipped because they are train time only."""
9971098
return inexpr
@@ -1056,6 +1157,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
10561157
"Permute": _convert_permute,
10571158
"Embedding": _convert_embedding,
10581159
"RepeatVector": _convert_repeat_vector,
1160+
"Lambda": _convert_lambda,
10591161
"InputLayer": _default_skip,
10601162
"Dropout": _default_skip,
10611163
"AlphaDropout": _default_skip,

tests/python/frontend/keras/test_forward.py

+21
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,26 @@ def test_forward_nested_layers(self, keras):
604604
)
605605
verify_keras_frontend(keras_model)
606606

607+
def test_forward_l2_normalize(self, keras):
608+
data = keras.layers.Input(shape=(16, 12, 8))
609+
K = keras.backend
610+
l2_funcs = [
611+
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=-2)),
612+
keras.layers.Lambda(lambda v: K.l2_normalize(x=v, axis=-1)),
613+
keras.layers.Lambda(lambda v: K.l2_normalize(axis=1, x=v)),
614+
keras.layers.Lambda(lambda v: K.l2_normalize(v, 2)),
615+
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=3)),
616+
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=(2, 3))),
617+
keras.layers.Lambda(lambda v: K.l2_normalize(v, (1, 2))),
618+
keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=[-2, -1])),
619+
keras.layers.Lambda(lambda v: K.l2_normalize(v, [-3, -2])),
620+
]
621+
for l2_func in l2_funcs:
622+
x = l2_func(data)
623+
keras_model = keras.models.Model(data, x)
624+
verify_keras_frontend(keras_model, layout="NCHW")
625+
verify_keras_frontend(keras_model, layout="NHWC")
626+
607627

608628
if __name__ == "__main__":
609629
for k in [keras, tf_keras]:
@@ -641,3 +661,4 @@ def test_forward_nested_layers(self, keras):
641661
sut.test_forward_zero_padding3d(keras=k)
642662
sut.test_forward_embedding(keras=k)
643663
sut.test_forward_repeat_vector(keras=k)
664+
sut.test_forward_l2_normalize(keras=k)

0 commit comments

Comments
 (0)