Skip to content

Commit

Permalink
Improve the generated tfr methods
Browse files Browse the repository at this point in the history
This patch added two improvments:

- allow the dtype to be assigned to the port directly in the TF op definition. These dtypes will be converted to speical attributes and we add them to the composition function directly;

- deduplicate the external function definitions. Multiple external function definitions were emitted if a new composition function is referred by using its python function.

Unit tests are updated.

PiperOrigin-RevId: 380229692
Change-Id: Icffde90dafb4a0dd2df213410c9cf1d8317464fd
  • Loading branch information
liufengdb authored and tensorflower-gardener committed Jun 18, 2021
1 parent a94c43a commit 0d9d692
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
21 changes: 15 additions & 6 deletions tensorflow/compiler/mlir/tfr/python/tfr_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def lookup(self, f_name, func_def=None, optional=False):
return (op_def, derived_attrs)

def mlir_external_funcs(self):
tfr_funcs = []
tfr_funcs = set()
for _, (op_def, derived_attrs) in sorted(self._op_defs.items()):
tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name))

Expand All @@ -297,7 +297,7 @@ def mlir_external_funcs(self):
attrs_with_default = [
attr for attr in non_derived_attrs if attr.HasField('default_value')
]
attr_names = set()
attr_names = {'f32_', 'i32_', 'i64_', 'i1_'} # reserved
for attr_def in attrs_no_default + attrs_with_default:
inputs.append(_get_type_info_from_proto(None, attr_def))
attr_names.add(attr_def.name)
Expand All @@ -310,9 +310,9 @@ def mlir_external_funcs(self):
inputs = ','.join(inputs)
outputs = ','.join(outputs)
attrs = ','.join(sorted(derived_attrs.union(attr_names)))
tfr_funcs.append('{}{}) -> ({}) attributes {{{}}}'.format(
tfr_funcs.add('{}{}) -> ({}) attributes {{{}}}'.format(
tfr_func, inputs, outputs, attrs))
return tfr_funcs
return sorted(list(tfr_funcs))


_PY_TYPE_TO_TFR = {
Expand Down Expand Up @@ -1506,9 +1506,9 @@ def tfr_gen(func, op_defs):
return mlir_code


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
def tfr_funcs_gen_from_module(source, op_defs, method_prefix=None,
op_libraries=None):
"""Parse the input source module and emit the TFR functions."""
op_defs = OpDefCache()

# Load the op library so the op is added to the op registry. This is
# required when the op cc_library couldn't be statically linked in open
Expand Down Expand Up @@ -1547,4 +1547,13 @@ def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
py_funcs = sorted(py_funcs, key=lambda x: x.__code__.co_firstlineno)
mlir_funcs = [tfr_gen(func, op_defs) for func in py_funcs]

return mlir_funcs


def tfr_gen_from_module(source, method_prefix=None, op_libraries=None,
op_defs=OpDefCache()):
"""Parse the input source module and emit the TFR and external functions."""
mlir_funcs = tfr_funcs_gen_from_module(
source, op_defs, method_prefix, op_libraries)

return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
20 changes: 10 additions & 10 deletions tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,25 +507,25 @@ def test_tfr_tf_ops(self):
CHECK-NEXT: tfr.return %[[call]] : !tfr.tensor
CHECK-NEXT: }
CHECK-LABEL: tfr.func @tf__add_(!tfr.tensor<T>,!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T}
CHECK-LABEL: tfr.func @tf__add_(!tfr.tensor<T>,!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}
CHECK-LABEL: tfr.func @tf__concat_(!tfr.tensor<i32_>,!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T,i32_}
CHECK-LABEL: tfr.func @tf__concat_(!tfr.tensor<i32_>,!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T,f32_,i1_,i32_,i64_}
CHECK-LABEL: tfr.func @tf__identity_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T}
CHECK-LABEL: tfr.func @tf__identity_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}
CHECK-LABEL: tfr.func @tf__pack_(!tfr.tensor_list<N,T>,i64{tfr.name="axis",tfr.type="int"}) -> (!tfr.tensor<T>) attributes {N,T,axis}
CHECK-LABEL: tfr.func @tf__pack_(!tfr.tensor_list<N,T>,i64{tfr.name="axis",tfr.type="int"}) -> (!tfr.tensor<T>) attributes {N,T,axis,f32_,i1_,i32_,i64_}
CHECK-LABEL: tfr.func @tf__split_v_(!tfr.tensor<T>,!tfr.tensor<Tlen>,!tfr.tensor<i32_>,i64{tfr.name="num_split",tfr.type="int"}) -> (!tfr.tensor_list<num_split,T>) attributes {T,Tlen,i32_,num_split}
CHECK-LABEL: tfr.func @tf__split_v_(!tfr.tensor<T>,!tfr.tensor<Tlen>,!tfr.tensor<i32_>,i64{tfr.name="num_split",tfr.type="int"}) -> (!tfr.tensor_list<num_split,T>) attributes {T,Tlen,f32_,i1_,i32_,i64_,num_split}
CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor<T>,!tfr.tensor<T>,i1{tfr.name="pred",tfr.type="bool"}) -> (!tfr.tensor<T>) attributes {T,pred}
CHECK-LABEL: tfr.func @tf__test_complex_tf_op_(!tfr.tensor<T>,!tfr.tensor<Tlen>,i64{tfr.name="N",tfr.type="int"}) -> (!tfr.tensor_list<N,T>) attributes {N,T,Tlen,f32_,i1_,i32_,i64_}
CHECK-LABEL: tfr.func @tf__test_complex_tf_op_(!tfr.tensor<T>,!tfr.tensor<Tlen>,i64{tfr.name="N",tfr.type="int"}) -> (!tfr.tensor_list<N,T>) attributes {N,T,Tlen}
CHECK-LABEL: tfr.func @tf__test_identity_op_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}
CHECK-LABEL: tfr.func @tf__test_identity_op_(!tfr.tensor<T>) -> (!tfr.tensor<T>) attributes {T}
CHECK-LABEL: tfr.func @tf__test_input_n_op_(!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T,f32_,i1_,i32_,i64_}
CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor<T>,!tfr.tensor<T>,i1{tfr.name="pred",tfr.type="bool"}) -> (!tfr.tensor<T>) attributes {T,pred}
CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor<T>,!tfr.tensor<T>,i1{tfr.name="pred",tfr.type="bool"}) -> (!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_,pred}
CHECK-LABEL: tfr.func @tf__test_input_n_op_(!tfr.tensor_list<N,T>) -> (!tfr.tensor<T>) attributes {N,T}
CHECK-LABEL: tfr.func @tf__test_two_outputs_op_(!tfr.tensor<T>) -> (!tfr.tensor<T>,!tfr.tensor<T>) attributes {T,f32_,i1_,i32_,i64_}
"""
self._check_code(mlir_code, mlir_code_exp)

Expand Down

0 comments on commit 0d9d692

Please sign in to comment.