Skip to content

Commit

Permalink
Fix linked parameters in reference implementation for Concat (onnx#4797)
Browse files Browse the repository at this point in the history
Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Feb 2, 2023
1 parent 9a5587d commit 91125d1
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 5 deletions.
9 changes: 4 additions & 5 deletions onnx/reference/ops/op_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@


class Concat(OpRun):
def _preprocess(self, a: np.ndarray) -> np.ndarray:
def _preprocess(self, a: np.ndarray, axis: int) -> np.ndarray:
if len(a.shape) == 0:
raise RuntimeError(f"Concat: one input has an empty shape: {a!r}.")
if self.axis >= len(a.shape): # type: ignore
new_shape = a.shape + (1,) * (self.axis + 1 - len(a.shape)) # type: ignore
if axis >= len(a.shape): # type: ignore
new_shape = a.shape + (1,) * (axis + 1 - len(a.shape)) # type: ignore
return a.reshape(new_shape)
return a

def _run(self, *args, axis=None): # type: ignore
axis = axis or self.axis # type: ignore
targs = tuple(self._preprocess(a) for a in args)
targs = tuple(self._preprocess(a, axis) for a in args)
return (np.concatenate(targs, axis),) # type: ignore
61 changes: 61 additions & 0 deletions onnx/test/reference_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2814,6 +2814,67 @@ def test_mvn(self, opset: int, ref_opset: int = 13):
self.assertEqual(expected.shape, got.shape)
assert_allclose(expected, got)

def test_concat_in_a_function(self):
def create_model():
nodes = []
inputs = []
outputs = []
functions = []

opsets = {"": onnx_opset_version(), "custom_domain": 1}
nodes_fct = []
node = make_node("Concat", ["x:0", "x:1"], ["r__0"], axis=0, domain="")
nodes_fct.append(node)

opset_imports_fct = [
make_opsetid(domain, 1 if version is None else version)
for domain, version in opsets.items()
]
fct = make_function(
"custom_domain",
"concat_2",
["x:0", "x:1"],
["r__0"],
nodes_fct,
opset_imports_fct,
)
functions.append(fct)

inputs.append(make_tensor_value_info("I__0", TensorProto.DOUBLE, []))
inputs.append(make_tensor_value_info("I__1", TensorProto.DOUBLE, []))
inputs.append(make_tensor_value_info("I__2", TensorProto.DOUBLE, []))
outputs.append(make_tensor_value_info("r__4", TensorProto.DOUBLE, []))

node = make_node(
"concat_2", ["I__0", "I__1"], ["r__3"], axis=0, domain="custom_domain"
)
nodes.append(node)
node = make_node(
"concat_2", ["I__2", "r__3"], ["r__4"], axis=0, domain="custom_domain"
)
nodes.append(node)
opset_imports = [
make_opsetid(domain, 1 if version is None else version)
for domain, version in opsets.items()
]

graph = make_graph(nodes, "numpyx", inputs, outputs)

onnx_model = make_model(
graph, opset_imports=opset_imports, functions=functions
)
return onnx_model

onnx_model = create_model()
x1 = np.array([[-5, 6], [15, 3]], dtype=np.float64)
x2 = np.array([[1, 2]], dtype=np.float64)
x3 = np.array([[-1, -2]], dtype=np.float64)
z = np.vstack([x1, x2, x3])
ref = ReferenceEvaluator(onnx_model)
feeds = {"I__2": x1, "I__0": x2, "I__1": x3}
got = ref.run(None, feeds)
assert_allclose(z, got[0])

def test_cast_float_to_string(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None])
Y = make_tensor_value_info("Y", TensorProto.STRING, [None])
Expand Down

0 comments on commit 91125d1

Please sign in to comment.