diff --git a/runtime/bindings/python/src/compatibility/ngraph/__init__.py b/runtime/bindings/python/src/compatibility/ngraph/__init__.py index 8b12a3c7ff9d82..662134c0f48045 100644 --- a/runtime/bindings/python/src/compatibility/ngraph/__init__.py +++ b/runtime/bindings/python/src/compatibility/ngraph/__init__.py @@ -166,6 +166,7 @@ from ngraph.opset8 import sign from ngraph.opset8 import sin from ngraph.opset8 import sinh +from ngraph.opset8 import slice from ngraph.opset8 import softmax from ngraph.opset8 import softplus from ngraph.opset8 import space_to_batch diff --git a/runtime/bindings/python/src/compatibility/ngraph/opset8/__init__.py b/runtime/bindings/python/src/compatibility/ngraph/opset8/__init__.py index f0d0dfdd2dbf64..74029d8869ef64 100644 --- a/runtime/bindings/python/src/compatibility/ngraph/opset8/__init__.py +++ b/runtime/bindings/python/src/compatibility/ngraph/opset8/__init__.py @@ -55,7 +55,7 @@ from ngraph.opset1.ops import floor_mod from ngraph.opset8.ops import gather from ngraph.opset6.ops import gather_elements -from ngraph.opset5.ops import gather_nd +from ngraph.opset8.ops import gather_nd from ngraph.opset1.ops import gather_tree from ngraph.opset7.ops import gelu from ngraph.opset1.ops import greater @@ -140,6 +140,7 @@ from ngraph.opset1.ops import sign from ngraph.opset1.ops import sin from ngraph.opset1.ops import sinh +from ngraph.opset8.ops import slice from ngraph.opset1.ops import softmax from ngraph.opset4.ops import softplus from ngraph.opset2.ops import space_to_batch diff --git a/runtime/bindings/python/src/compatibility/ngraph/opset8/ops.py b/runtime/bindings/python/src/compatibility/ngraph/opset8/ops.py index 6c355930b7c021..fdf71ea6f86c1e 100644 --- a/runtime/bindings/python/src/compatibility/ngraph/opset8/ops.py +++ b/runtime/bindings/python/src/compatibility/ngraph/opset8/ops.py @@ -367,3 +367,53 @@ def random_uniform( "op_seed": op_seed, } return _get_node_factory_opset8().create("RandomUniform", inputs, attributes) + + +@nameable_op +def slice( + data: NodeInput, + start: NodeInput, + stop: NodeInput, + step: NodeInput, + axes: Optional[NodeInput] = None, + name: Optional[str] = None, +) -> Node: + """Return a node which generates Slice operation. + + @param data: The node providing input data. + @param start: The node providing start indices (inclusively). + @param stop: The node providing stop indices (exclusively). + @param step: The node providing step values. + @param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1]. + @param name: The optional name for the created output node. + @return The new node performing Slice operation. + """ + if axes is None: + inputs = as_nodes(data, start, stop, step) + else: + inputs = as_nodes(data, start, stop, step, axes) + + return _get_node_factory_opset8().create("Slice", inputs) + + +@nameable_op +def gather_nd( + data: NodeInput, + indices: NodeInput, + batch_dims: Optional[int] = 0, + name: Optional[str] = None, +) -> Node: + """Return a node which performs GatherND. + + @param data: N-D tensor with data for gathering + @param indices: K-D tensor of tuples with indices by which data is gathered + @param batch_dims: Scalar value of batch dimensions + @return: The new node which performs GatherND + """ + inputs = as_nodes(data, indices) + + attributes = { + "batch_dims": batch_dims + } + + return _get_node_factory_opset8().create("GatherND", inputs, attributes) diff --git a/runtime/bindings/python/src/openvino/opset8/ops.py b/runtime/bindings/python/src/openvino/opset8/ops.py index ef29f36cf083c0..67559a8dc49dab 100644 --- a/runtime/bindings/python/src/openvino/opset8/ops.py +++ b/runtime/bindings/python/src/openvino/opset8/ops.py @@ -375,7 +375,8 @@ def slice( start: NodeInput, stop: NodeInput, step: NodeInput, - axes: NodeInput = None + axes: Optional[NodeInput] = None, + name: Optional[str] = None, ) -> Node: """Return a node which generates Slice operation. @@ -384,6 +385,8 @@ def slice( @param stop: The node providing stop indices (exclusively). @param step: The node providing step values. @param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1]. + @param name: The optional name for the created output node. + @return The new node performing Slice operation. """ if axes is None: inputs = as_nodes(data, start, stop, step) diff --git a/runtime/bindings/python/tests_compatibility/test_ngraph/test_create_op.py b/runtime/bindings/python/tests_compatibility/test_ngraph/test_create_op.py index 673d7a2ebf10b4..c5d97de1753877 100644 --- a/runtime/bindings/python/tests_compatibility/test_ngraph/test_create_op.py +++ b/runtime/bindings/python/tests_compatibility/test_ngraph/test_create_op.py @@ -1923,3 +1923,31 @@ def test_matrix_nms(): assert nms_node.get_output_element_type(0) == Type.f32 assert nms_node.get_output_element_type(1) == Type.i32 assert nms_node.get_output_element_type(2) == Type.i32 + + +def test_slice(): + data_shape = [10, 7, 2, 13] + data = ng.parameter(data_shape, name="input", dtype=np.float32) + + start = ng.constant(np.array([2, 0, 0], dtype=np.int32)) + stop = ng.constant(np.array([9, 7, 2], dtype=np.int32)) + step = ng.constant(np.array([2, 1, 1], dtype=np.int32)) + + node_default_axes = ng.slice(data, start, stop, step) + + assert node_default_axes.get_type_name() == "Slice" + assert node_default_axes.get_output_size() == 1 + assert node_default_axes.get_output_element_type(0) == Type.f32 + assert tuple(node_default_axes.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape + + start = ng.constant(np.array([0, 2], dtype=np.int32)) + stop = ng.constant(np.array([2, 9], dtype=np.int32)) + step = ng.constant(np.array([1, 2], dtype=np.int32)) + axes = ng.constant(np.array([-2, 0], dtype=np.int32)) + + node = ng.slice(data, start, stop, step, axes) + + assert node.get_type_name() == "Slice" + assert node.get_output_size() == 1 + assert node.get_output_element_type(0) == Type.f32 + assert tuple(node.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape diff --git a/runtime/bindings/python/tests_compatibility/test_ngraph/test_data_movement.py b/runtime/bindings/python/tests_compatibility/test_ngraph/test_data_movement.py index 5873057f67957c..7f0ff39c9c4bfc 100644 --- a/runtime/bindings/python/tests_compatibility/test_ngraph/test_data_movement.py +++ b/runtime/bindings/python/tests_compatibility/test_ngraph/test_data_movement.py @@ -196,6 +196,21 @@ def test_gather_nd(): batch_dims = 2 expected_shape = [20, 30, 40, 50] + node = ng.opset5.gather_nd(data, indices, batch_dims) + assert node.get_type_name() == "GatherND" + assert node.get_output_size() == 1 + assert list(node.get_output_shape(0)) == expected_shape + assert node.get_output_element_type(0) == Type.f32 + + +def test_gather_v8_nd(): + indices_type = np.int32 + data_dtype = np.float32 + data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data") + indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices") + batch_dims = 2 + expected_shape = [2, 10, 30, 40, 50] + node = ng.gather_nd(data, indices, batch_dims) assert node.get_type_name() == "GatherND" assert node.get_output_size() == 1