Skip to content

Commit

Permalink
Fix onnx model maker code gen
Browse files Browse the repository at this point in the history
  • Loading branch information
fumihwh committed May 5, 2021
1 parent 0a793ea commit 85c2ca8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
36 changes: 22 additions & 14 deletions onnx_model_maker/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import onnx

TENSOR_PREFIX = "_t_"
HEADER = f'''# Autogenerated by onnx-model-maker. Don't modify it manually.
AUTO_GEN_HEAD = "# Autogenerated by onnx-model-maker. Don't modify it manually."
HEADER = f'''{AUTO_GEN_HEAD}
import onnx
import onnx.helper
Expand All @@ -15,7 +16,9 @@
from onnx_model_maker.ops.op_helper import _add_input
'''

OP_HELPER_PY = f'''from uuid import uuid4
OP_HELPER_PY = f'''{AUTO_GEN_HEAD}
from uuid import uuid4
import numpy
import onnx
Expand All @@ -32,13 +35,20 @@ def _add_input(target, inputs):
inputs.append(t.name)
elif type(target) == str:
inputs.append(target)
elif type(target) == list and all([type(i) == str for i in target]):
inputs.extend(target)
elif type(target) == list:
_add_list(target, inputs)
elif type(target) == onnx.NodeProto:
inputs.append(target.output[0])
def _add_list(target, inputs):
for t in target:
_add_input(t, inputs)
'''

INIT_PY = f'''import glob
INIT_PY = f'''{AUTO_GEN_HEAD}
import glob
import importlib
import os
import sys
Expand Down Expand Up @@ -88,6 +98,10 @@ def Output(*args):
'''

NEW_LINE = '''
'''


def _gen_op_maker(schema):
onnx_op = schema.name
Expand Down Expand Up @@ -155,18 +169,12 @@ def gen(output_dir=None, overwrite=False):
file_contents[str(since_version)].append(_gen_op_maker(schema))
for v, c in file_contents.items():
with open(os.path.join(output_dir, f"op_ver_{v}.py"), "w") as f:
f.write('''
'''.join(c))
f.write(NEW_LINE.join(c))
with open(os.path.join(output_dir, "__init__.py"), "w") as f:
f.write(INIT_PY)
f.write('''
'''.join([abs_op_contents[key] for key in sorted(abs_op_contents.keys())]))
f.write(NEW_LINE.join([abs_op_contents[key] for key in sorted(abs_op_contents.keys())]))
all_str = ', '.join([f'"{key}"' for key in sorted(abs_op_contents.keys())])
f.write(f'''
__all__ = [\"Input\", \"Output\", {all_str}]''')
f.write(f'''{NEW_LINE}__all__ = [\"Input\", \"Output\", {all_str}]''')
with open(os.path.join(output_dir, "op_helper.py"), "w") as f:
f.write(OP_HELPER_PY)

Expand Down
2 changes: 2 additions & 0 deletions onnx_model_maker/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Autogenerated by onnx-model-maker. Don't modify it manually.

import glob
import importlib
import os
Expand Down
2 changes: 2 additions & 0 deletions onnx_model_maker/ops/op_helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Autogenerated by onnx-model-maker. Don't modify it manually.

from uuid import uuid4

import numpy
Expand Down

0 comments on commit 85c2ca8

Please sign in to comment.