diff --git a/onnx/backend/test/case/node/scatternd.py b/onnx/backend/test/case/node/scatternd.py index 86c292900be..8fb80527689 100644 --- a/onnx/backend/test/case/node/scatternd.py +++ b/onnx/backend/test/case/node/scatternd.py @@ -19,15 +19,15 @@ def scatter_nd_impl(data, indices, updates, reduction="none"): # type: ignore for i in np.ndindex(indices.shape[:-1]): # NOTE: The order of iteration in this loop is not specified. if reduction == "add": - output[indices[i]] += updates[i] + output[tuple(indices[i])] += updates[i] elif reduction == "mul": - output[indices[i]] *= updates[i] + output[tuple(indices[i])] *= updates[i] elif reduction == "max": - output[indices[i]] = np.maximum(output[indices[i]], updates[i]) + output[tuple(indices[i])] = np.maximum(output[indices[i]], updates[i]) elif reduction == "min": - output[indices[i]] = np.minimum(output[indices[i]], updates[i]) + output[tuple(indices[i])] = np.minimum(output[indices[i]], updates[i]) else: - output[indices[i]] = updates[i] + output[tuple(indices[i])] = updates[i] return output diff --git a/onnx/reference/ops/op_scatternd.py b/onnx/reference/ops/op_scatternd.py index 8d64e23e447..e9f88c5596f 100644 --- a/onnx/reference/ops/op_scatternd.py +++ b/onnx/reference/ops/op_scatternd.py @@ -12,15 +12,15 @@ def _scatter_nd_impl(data, indices, updates, reduction=None): # type: ignore output = np.copy(data) for i in np.ndindex(indices.shape[:-1]): if reduction == "add": - output[indices[i]] += updates[i] + output[tuple(indices[i])] += updates[i] elif reduction == "mul": - output[indices[i]] *= updates[i] + output[tuple(indices[i])] *= updates[i] elif reduction == "max": - output[indices[i]] = np.maximum(output[indices[i]], updates[i]) + output[tuple(indices[i])] = np.maximum(output[indices[i]], updates[i]) elif reduction == "min": - output[indices[i]] = np.minimum(output[indices[i]], updates[i]) + output[tuple(indices[i])] = np.minimum(output[indices[i]], updates[i]) else: - output[indices[i]] = updates[i] + output[tuple(indices[i])] = updates[i] return output diff --git a/onnx/test/reference_evaluator_test.py b/onnx/test/reference_evaluator_test.py index 46d2726e40f..4e9e65abeed 100644 --- a/onnx/test/reference_evaluator_test.py +++ b/onnx/test/reference_evaluator_test.py @@ -2075,6 +2075,30 @@ def test_scatter_elements(self): expected = np.array([[1.0, 1.1, 3.0, 4.0, 5.0]], dtype=np.float32) assert_allclose(expected, got1[0]) + def test_scatternd(self): + X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + Ind = make_tensor_value_info("I", TensorProto.INT64, [None, None]) + U = make_tensor_value_info("U", TensorProto.FLOAT, [None]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) + + node = make_node( + "ScatterND", + ["X", "I", "U"], + ["Y"], + ) + graph = make_graph([node], "g", [X, Ind, U], [Y]) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 16)]) + feeds = { + "X": np.array([[1.0, 2.0]], dtype=np.float32), + "I": np.array([[0, 0]]), + "U": np.array([3.0], dtype=np.float32), + } + + ref1 = ReferenceEvaluator(onnx_model) + got1 = ref1.run(None, feeds) + expected = np.array([[3.0, 2.0]], dtype=np.float32) + assert_allclose(expected, got1[0]) + def test_col2im_impl(self): def get_im2col_indices( x_shape, field_height, field_width, padding=None, stride=1