diff --git a/python/nnvm/frontend/onnx.py b/python/nnvm/frontend/onnx.py
index 667c1787b..fcd15db1f 100644
--- a/python/nnvm/frontend/onnx.py
+++ b/python/nnvm/frontend/onnx.py
@@ -3,6 +3,8 @@
 from __future__ import absolute_import as _abs
 import tvm
 from .. import symbol as _sym
+from .. import graph as _graph
+from .. compiler import graph_util
 from .common import Renamer, AttrConverter as AttrCvt
 
 __all__ = ['from_onnx']
@@ -60,9 +62,9 @@ def _pooling(name):
             'kernel_shape': 'pool_size',
             'pads': ('padding', (0, 0), _revert_caffe2_pad)},
         # very weird attributes here in onnx, force check
-        excludes=['dilations'],
+        ignores=['dilations'],
         # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
-        extras={'ceil_mode': True},
+        extras={'ceil_mode': False},
         custom_check=_dimension_constraint())
 
 def _conv():
@@ -90,7 +92,7 @@ def _batch_norm():
     return AttrCvt(
         op_name='batch_norm',
         disables=['momentum'],
-        ignores=['spatial', 'is_test'])
+        ignores=['spatial', 'is_test', 'consumed_inputs'])
 
 
 # compatible operators that do NOT require any conversion.
@@ -100,6 +102,7 @@ def _batch_norm():
 _convert_map = {
     # defs/experimental
     'FC'            : AttrCvt('dense', ignores=['axis', 'axis_w']),
+    'SpatialBN'     : _batch_norm(),
 
     # defs/generator
     # 'Constant'
@@ -200,7 +203,7 @@ def _convert_operator(op_name, attrs, identity_list=None, convert_map=None):
     elif op_name in convert_map:
         op_name, attrs = convert_map[op_name](attrs)
     else:
-        _raise_not_supported('Operator: ' + op_name)
+        raise NotImplementedError("Operator {} not implemented.".format(op_name))
     op = getattr(_sym, op_name, None)
     if not op:
         raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
@@ -267,10 +270,11 @@ def from_onnx(self, graph):
             new_attr = self._fix_channels(new_op, new_attr, list(node.input))
             self._fix_bias_shape(node.op_type, graph.node[idx-1].op_type, node.input)
             op = new_op(name=node_name, *inputs, **new_attr)
-            assert len(node.output) == len(op.list_output_names()), (
-                "Number of output mismatch {} vs {}.".format(
-                    len(node.output), len(op.list_output_names())))
-            for k, i in zip(list(node.output), range(len(node.output))):
+            node_output = self._fix_outputs(op_name, node.output)
+            assert len(node_output) == len(op.list_output_names()), (
+                "Number of outputs mismatch: {} vs {} in {}.".format(
+                    len(node_output), len(op.list_output_names()), op_name))
+            for k, i in zip(list(node_output), range(len(node_output))):
                 self._nodes[k] = op[i]
         # now return the outputs
         out = [self._nodes[i] for i in graph.output]
@@ -310,6 +314,15 @@ def _parse_attr(self, attr_proto):
                 raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
         return attrs
 
+    def _fix_outputs(self, op, outputs):
+        """A hack to handle dropout or similar operators that have more than
+        one output in ONNX.
+        """
+        if op == 'Dropout':
+            assert len(outputs) == 2, "ONNX Dropout has two outputs."
+            outputs = outputs[:-1]
+        return outputs
+
     def _fix_bias(self, op, attrs, num_inputs):
         """A hack for 'use_bias' attribute since onnx don't provide
         this attribute, we have to check the number of inputs to decide it."""
@@ -340,17 +353,24 @@ def _fix_channels(self, op, attrs, inputs):
         """
         if op not in [_sym.conv2d, _sym.conv2d_transpose, _sym.dense]:
             return attrs
-        weight_name = self._renames[inputs[1]]
-        if not weight_name in self._params:
-            raise ValueError("Unable to get channels/units attr from onnx graph.")
+        if inputs[1] not in self._renames:
+            assert inputs[1] in self._nodes
+            g = _graph.create(self._nodes[inputs[1]])
+            shape_dict = {k: v.shape for k, v in self._params.items()}
+            _, out_shapes = graph_util.infer_shape(g, **shape_dict)
+            channels = out_shapes[0][0]
         else:
-            wshape = self._params[weight_name].shape
-            assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
-            channels = wshape[0]
-            if op in [_sym.dense]:
-                attrs['units'] = channels
+            weight_name = self._renames[inputs[1]]
+            if weight_name not in self._params:
+                raise ValueError("Unable to get channels/units attr from onnx graph.")
             else:
-                attrs['channels'] = channels
+                wshape = self._params[weight_name].shape
+                assert len(wshape) >= 2, "Weights shape is invalid: {}".format(wshape)
+                channels = wshape[0]
+        if op in [_sym.dense]:
+            attrs['units'] = channels
+        else:
+            attrs['channels'] = channels
         return attrs
 
 def from_onnx(graph):
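For context, a minimal usage sketch of the frontend this patch modifies, assuming `onnx` and `nnvm` are installed; the path "model.onnx" is a placeholder, not a file referenced by the patch. Per the `def from_onnx(graph):` signature above, the entry point takes an ONNX GraphProto and returns an nnvm symbol together with a dict of parameter tensors.

    import onnx
    import nnvm.frontend

    # Load the serialized ONNX ModelProto; "model.onnx" is a placeholder path.
    model = onnx.load("model.onnx")

    # from_onnx consumes the GraphProto and yields an nnvm Symbol plus the
    # parameter dict that nnvm.compiler.build takes later.
    sym, params = nnvm.frontend.from_onnx(model.graph)
    print(sym.list_input_names())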