diff --git a/python/nnvm/frontend/onnx.py b/python/nnvm/frontend/onnx.py index fcd15db1f..da6340607 100644 --- a/python/nnvm/frontend/onnx.py +++ b/python/nnvm/frontend/onnx.py @@ -245,18 +245,24 @@ def from_onnx(self, graph): raise ValueError("Tensor's name is required.") self._params[init_tensor.name] = self._parse_array(init_tensor) for i in graph.input: - if i in self._params: + # from onnx v0.2, GraphProto.input has type ValueInfoProto, + # and the name is 'i.name' + try: + i_name = i.name + except AttributeError: + i_name = i + if i_name in self._params: # i is a param instead of input name_param = 'param_{}'.format(self._num_param) self._num_param += 1 - self._params[name_param] = self._params.pop(i) + self._params[name_param] = self._params.pop(i_name) self._nodes[name_param] = _sym.Variable(name=name_param) - self._renames[i] = name_param + self._renames[i_name] = name_param else: name_input = 'input_{}'.format(self._num_input) self._num_input += 1 self._nodes[name_input] = _sym.Variable(name=name_input) - self._renames[i] = name_input + self._renames[i_name] = name_input # construct nodes, nodes are stored as directed acyclic graph for idx, node in enumerate(graph.node): op_name = node.op_type