|
 # under the License.
 # pylint: disable=invalid-name, import-self, import-outside-toplevel
 """Keras frontend."""
+import dis
 import sys
 import numpy as np
 import tvm
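The new `dis` import is what drives the converter below: rather than calling the Lambda layer's function, the frontend disassembles its bytecode to see which backend function it wraps and what arguments it passes. A runnable sketch of that inspection, with a stub standing in for `keras.backend` (the `_Backend` name is illustrative, and exact opcodes vary across CPython versions):

```python
import dis


# Stand-in for keras.backend so the snippet runs without Keras installed.
class _Backend:
    @staticmethod
    def l2_normalize(x, axis=None):
        return x


K = _Backend()
fn = lambda x: K.l2_normalize(x, axis=3)  # the kind of lambda a Lambda layer wraps

for ins in dis.get_instructions(fn):
    print(ins.opname, ins.argval)
# On CPython 3.7 this prints roughly:
#   LOAD_GLOBAL K
#   LOAD_ATTR l2_normalize
#   LOAD_FAST x
#   LOAD_CONST 3
#   LOAD_CONST ('axis',)
#   CALL_FUNCTION_KW 2
#   RETURN_VALUE
```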
|
@@ -988,10 +989,110 @@ def _convert_repeat_vector(inexpr, keras_layer, _):
     out_shape = [-1, repeats] + input_shape[1:]
     out = _op.repeat(inexpr, repeats=repeats, axis=0)
     out = _op.reshape(out, out_shape)
-
     return out
 
 
+def _convert_l2_normalize(inexpr, keras_layer, etab):
+    # The Lambda layer only carries an opaque Python function, so recover the
+    # arguments passed to K.l2_normalize by walking the lambda's bytecode.
+    l2_normalize_is_loaded = False
+    param_list = []
+    for i in dis.get_instructions(keras_layer.function):
+        if i.opname in ["LOAD_GLOBAL", "LOAD_DEREF"]:
+            continue
+        if i.opname in ["LOAD_ATTR", "LOAD_METHOD"]:
+            if i.argval == "l2_normalize":
+                assert not l2_normalize_is_loaded, "l2_normalize was already LOADED"
+                l2_normalize_is_loaded = True
+        elif i.opname in ["LOAD_CONST", "LOAD_FAST"] and l2_normalize_is_loaded:
+            param_list.append(i.argval)
+        elif i.opname == "BUILD_LIST":
+            sz = i.argval
+            assert len(param_list) >= sz
+            new_list = param_list[-sz:]
+            param_list = param_list[:-sz]
+            param_list.append(new_list)
+        elif i.opname in ["CALL_FUNCTION_KW", "CALL_METHOD"]:
+            break
+
+    axis = None
+    is_param_list_parsed = False
+    if l2_normalize_is_loaded and len(param_list) > 0:
+        # If the last param_list item is a tuple of strings, the lambda
+        # calls l2_normalize with named parameters.
+        if (
+            isinstance(param_list[-1], tuple)
+            and len(param_list[-1]) > 0
+            and isinstance(param_list[-1][0], str)
+        ):
+            param_names = param_list[-1]
+            if len(param_names) == 1 and param_names[0] == "x":
+                # lambda v: K.l2_normalize(x=v)
+                axis = None
+                is_param_list_parsed = True
+            elif len(param_names) == 1 and param_names[0] == "axis" and len(param_list) == 3:
+                # lambda x: K.l2_normalize(x, axis=(2,3))
+                axis = param_list[1]
+                is_param_list_parsed = True
+            elif len(param_names) == 2 and len(param_list) == 3:
+                # lambda x: K.l2_normalize(x=x, axis=(2,3))
+                # lambda x: K.l2_normalize(axis=(2,3), x=x)
+                axis = param_list[param_names.index("axis")]
+                is_param_list_parsed = True
+        else:
+            # lambda x: K.l2_normalize(x)
+            if len(param_list) == 1:
+                axis = None
+                is_param_list_parsed = True
+            # lambda x: K.l2_normalize(x, (2,3))
+            elif len(param_list) == 2:
+                axis = param_list[1]
+                is_param_list_parsed = True
+
+    def is_int_or_tuple_of_ints(v):
+        if isinstance(v, list) and len(v) > 0:
+            for i in v:
+                if not isinstance(i, int):
+                    return False
+            return True
+        if isinstance(v, tuple) and len(v) > 0:
+            return isinstance(v[0], int)
+        return isinstance(v, int)
+
+    assert is_param_list_parsed and (
+        axis is None or is_int_or_tuple_of_ints(axis)
+    ), "Cannot parse the l2_normalize lambda found in a Lambda layer"
+    if isinstance(axis, int):
+        axis = [axis]
+
+    # Remap Keras (channels-last) axes to NCHW. axis=None means "normalize
+    # over all dims" and needs no remapping (iterating it would raise).
+    if etab.data_layout == "NCHW" and axis is not None:
+        dims = len(keras_layer.input_shape)
+
+        def fix_axis_for_nchw(axis):
+            if axis == 0:
+                return 0
+            if axis in [(dims - 1), -1]:
+                return 1
+            return axis + 1
+
+        axis = [fix_axis_for_nchw(x) for x in axis]
+    return _op.nn.l2_normalize(inexpr, eps=1e-12, axis=axis)
+
+
+def _convert_lambda(inexpr, keras_layer, etab):
+    fcode = keras_layer.function.__code__
+    # Dispatch on the last global name the lambda references; currently only
+    # l2_normalize is recognized.
+    if (
+        fcode.co_name == "<lambda>"
+        and len(fcode.co_names) > 0
+        and fcode.co_names[-1] == "l2_normalize"
+    ):
+        return _convert_l2_normalize(inexpr, keras_layer, etab)
+    raise tvm.error.OpNotImplemented(
+        "Function {} used in Lambda layer is not supported in frontend Keras.".format(
+            fcode.co_names
+        )
+    )
+
+
 def _default_skip(inexpr, keras_layer, _):  # pylint: disable=unused-argument
     """Layers that can be skipped because they are train time only."""
     return inexpr
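For context, a minimal end-to-end sketch of what this adds: a Keras model whose `Lambda` layer wraps `K.l2_normalize` can now be imported into Relay. The `"data"` input name and the shapes are illustrative; `relay.frontend.from_keras` defaults to `layout="NCHW"`, which is why `_convert_l2_normalize` remaps the channels-last `axis=-1` to the NCHW channel axis 1:

```python
import keras
from keras import backend as K
from tvm import relay

# Hypothetical toy model; any Lambda wrapping K.l2_normalize is eligible.
inp = keras.layers.Input(shape=(8, 8, 3), name="data")
out = keras.layers.Lambda(lambda x: K.l2_normalize(x, axis=-1))(inp)
model = keras.models.Model(inp, out)

# Shape keys must match input names; the NCHW shape mirrors the NHWC input.
mod, params = relay.frontend.from_keras(model, shape={"data": (1, 3, 8, 8)})
print(mod)
```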
|
@@ -1056,6 +1157,7 @@ def _default_skip(inexpr, keras_layer, _):  # pylint: disable=unused-argument
     "Permute": _convert_permute,
     "Embedding": _convert_embedding,
     "RepeatVector": _convert_repeat_vector,
+    "Lambda": _convert_lambda,
     "InputLayer": _default_skip,
     "Dropout": _default_skip,
     "AlphaDropout": _default_skip,
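For reference, these are the call shapes the bytecode parser accepts, mirroring the comments in `_convert_l2_normalize` (illustrative snippets, assuming `K` is `keras.backend`); any other Lambda function fails loudly rather than converting silently:

```python
from keras import backend as K
from keras.layers import Lambda

Lambda(lambda x: K.l2_normalize(x))               # default axis (None)
Lambda(lambda x: K.l2_normalize(x, 2))            # positional axis
Lambda(lambda x: K.l2_normalize(x, axis=(2, 3)))  # keyword axis
Lambda(lambda x: K.l2_normalize(axis=3, x=x))     # keywords in any order

# Unsupported, raises OpNotImplemented at conversion time:
# Lambda(lambda x: x ** 2)
```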
|
|