Skip to content

Commit

Permalink
Fix ScatterND implementation (onnx#5358)
Browse files Browse the repository at this point in the history
### Description
Changed the way the reference implementation implemented ScatterND. The
problem lied in the fact that currently the implementation uses [integer
array
indexing](https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing)
instead of multidimensional slicing.

### Motivation and Context
This change solves  onnx#5353.

Signed-off-by: Atanas Dimitrov <[email protected]>
  • Loading branch information
AtanasDimitrovQC authored Jun 26, 2023
1 parent e987130 commit edd695e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 10 deletions.
10 changes: 5 additions & 5 deletions onnx/backend/test/case/node/scatternd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ def scatter_nd_impl(data, indices, updates, reduction="none"): # type: ignore
for i in np.ndindex(indices.shape[:-1]):
# NOTE: The order of iteration in this loop is not specified.
if reduction == "add":
output[indices[i]] += updates[i]
output[tuple(indices[i])] += updates[i]
elif reduction == "mul":
output[indices[i]] *= updates[i]
output[tuple(indices[i])] *= updates[i]
elif reduction == "max":
output[indices[i]] = np.maximum(output[indices[i]], updates[i])
output[tuple(indices[i])] = np.maximum(output[indices[i]], updates[i])
elif reduction == "min":
output[indices[i]] = np.minimum(output[indices[i]], updates[i])
output[tuple(indices[i])] = np.minimum(output[indices[i]], updates[i])
else:
output[indices[i]] = updates[i]
output[tuple(indices[i])] = updates[i]
return output


Expand Down
10 changes: 5 additions & 5 deletions onnx/reference/ops/op_scatternd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ def _scatter_nd_impl(data, indices, updates, reduction=None): # type: ignore
output = np.copy(data)
for i in np.ndindex(indices.shape[:-1]):
if reduction == "add":
output[indices[i]] += updates[i]
output[tuple(indices[i])] += updates[i]
elif reduction == "mul":
output[indices[i]] *= updates[i]
output[tuple(indices[i])] *= updates[i]
elif reduction == "max":
output[indices[i]] = np.maximum(output[indices[i]], updates[i])
output[tuple(indices[i])] = np.maximum(output[indices[i]], updates[i])
elif reduction == "min":
output[indices[i]] = np.minimum(output[indices[i]], updates[i])
output[tuple(indices[i])] = np.minimum(output[indices[i]], updates[i])
else:
output[indices[i]] = updates[i]
output[tuple(indices[i])] = updates[i]
return output


Expand Down
24 changes: 24 additions & 0 deletions onnx/test/reference_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2075,6 +2075,30 @@ def test_scatter_elements(self):
expected = np.array([[1.0, 1.1, 3.0, 4.0, 5.0]], dtype=np.float32)
assert_allclose(expected, got1[0])

def test_scatternd(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Ind = make_tensor_value_info("I", TensorProto.INT64, [None, None])
U = make_tensor_value_info("U", TensorProto.FLOAT, [None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])

node = make_node(
"ScatterND",
["X", "I", "U"],
["Y"],
)
graph = make_graph([node], "g", [X, Ind, U], [Y])
onnx_model = make_model(graph, opset_imports=[make_opsetid("", 16)])
feeds = {
"X": np.array([[1.0, 2.0]], dtype=np.float32),
"I": np.array([[0, 0]]),
"U": np.array([3.0], dtype=np.float32),
}

ref1 = ReferenceEvaluator(onnx_model)
got1 = ref1.run(None, feeds)
expected = np.array([[3.0, 2.0]], dtype=np.float32)
assert_allclose(expected, got1[0])

def test_col2im_impl(self):
def get_im2col_indices(
x_shape, field_height, field_width, padding=None, stride=1
Expand Down

0 comments on commit edd695e

Please sign in to comment.