Skip to content

Commit

Permalink
[Python API] Add missing opest8 ops to compatibility python API (open…
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska authored Nov 19, 2021
1 parent 7e457bf commit 8399160
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@
from ngraph.opset8 import sign
from ngraph.opset8 import sin
from ngraph.opset8 import sinh
from ngraph.opset8 import slice
from ngraph.opset8 import softmax
from ngraph.opset8 import softplus
from ngraph.opset8 import space_to_batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from ngraph.opset1.ops import floor_mod
from ngraph.opset8.ops import gather
from ngraph.opset6.ops import gather_elements
from ngraph.opset5.ops import gather_nd
from ngraph.opset8.ops import gather_nd
from ngraph.opset1.ops import gather_tree
from ngraph.opset7.ops import gelu
from ngraph.opset1.ops import greater
Expand Down Expand Up @@ -140,6 +140,7 @@
from ngraph.opset1.ops import sign
from ngraph.opset1.ops import sin
from ngraph.opset1.ops import sinh
from ngraph.opset8.ops import slice
from ngraph.opset1.ops import softmax
from ngraph.opset4.ops import softplus
from ngraph.opset2.ops import space_to_batch
Expand Down
50 changes: 50 additions & 0 deletions runtime/bindings/python/src/compatibility/ngraph/opset8/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,53 @@ def random_uniform(
"op_seed": op_seed,
}
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)


@nameable_op
def slice(
data: NodeInput,
start: NodeInput,
stop: NodeInput,
step: NodeInput,
axes: Optional[NodeInput] = None,
name: Optional[str] = None,
) -> Node:
"""Return a node which generates Slice operation.
@param data: The node providing input data.
@param start: The node providing start indices (inclusively).
@param stop: The node providing stop indices (exclusively).
@param step: The node providing step values.
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
@param name: The optional name for the created output node.
@return The new node performing Slice operation.
"""
if axes is None:
inputs = as_nodes(data, start, stop, step)
else:
inputs = as_nodes(data, start, stop, step, axes)

return _get_node_factory_opset8().create("Slice", inputs)


@nameable_op
def gather_nd(
data: NodeInput,
indices: NodeInput,
batch_dims: Optional[int] = 0,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs GatherND.
@param data: N-D tensor with data for gathering
@param indices: K-D tensor of tuples with indices by which data is gathered
@param batch_dims: Scalar value of batch dimensions
@return: The new node which performs GatherND
"""
inputs = as_nodes(data, indices)

attributes = {
"batch_dims": batch_dims
}

return _get_node_factory_opset8().create("GatherND", inputs, attributes)
5 changes: 4 additions & 1 deletion runtime/bindings/python/src/openvino/opset8/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ def slice(
start: NodeInput,
stop: NodeInput,
step: NodeInput,
axes: NodeInput = None
axes: Optional[NodeInput] = None,
name: Optional[str] = None,
) -> Node:
"""Return a node which generates Slice operation.
Expand All @@ -384,6 +385,8 @@ def slice(
@param stop: The node providing stop indices (exclusively).
@param step: The node providing step values.
@param axes: The optional node providing axes to slice, default [0, 1, ..., len(start)-1].
@param name: The optional name for the created output node.
@return The new node performing Slice operation.
"""
if axes is None:
inputs = as_nodes(data, start, stop, step)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1923,3 +1923,31 @@ def test_matrix_nms():
assert nms_node.get_output_element_type(0) == Type.f32
assert nms_node.get_output_element_type(1) == Type.i32
assert nms_node.get_output_element_type(2) == Type.i32


def test_slice():
data_shape = [10, 7, 2, 13]
data = ng.parameter(data_shape, name="input", dtype=np.float32)

start = ng.constant(np.array([2, 0, 0], dtype=np.int32))
stop = ng.constant(np.array([9, 7, 2], dtype=np.int32))
step = ng.constant(np.array([2, 1, 1], dtype=np.int32))

node_default_axes = ng.slice(data, start, stop, step)

assert node_default_axes.get_type_name() == "Slice"
assert node_default_axes.get_output_size() == 1
assert node_default_axes.get_output_element_type(0) == Type.f32
assert tuple(node_default_axes.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape

start = ng.constant(np.array([0, 2], dtype=np.int32))
stop = ng.constant(np.array([2, 9], dtype=np.int32))
step = ng.constant(np.array([1, 2], dtype=np.int32))
axes = ng.constant(np.array([-2, 0], dtype=np.int32))

node = ng.slice(data, start, stop, step, axes)

assert node.get_type_name() == "Slice"
assert node.get_output_size() == 1
assert node.get_output_element_type(0) == Type.f32
assert tuple(node.get_output_shape(0)) == np.zeros(data_shape)[2:9:2, ::, 0:2:1].shape
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,21 @@ def test_gather_nd():
batch_dims = 2
expected_shape = [20, 30, 40, 50]

node = ng.opset5.gather_nd(data, indices, batch_dims)
assert node.get_type_name() == "GatherND"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == expected_shape
assert node.get_output_element_type(0) == Type.f32


def test_gather_v8_nd():
indices_type = np.int32
data_dtype = np.float32
data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data")
indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices")
batch_dims = 2
expected_shape = [2, 10, 30, 40, 50]

node = ng.gather_nd(data, indices, batch_dims)
assert node.get_type_name() == "GatherND"
assert node.get_output_size() == 1
Expand Down

0 comments on commit 8399160

Please sign in to comment.