Skip to content

Commit

Permalink
[MO] fixed dealing with None in then/else branches of Select (openvin…
Browse files Browse the repository at this point in the history
…otoolkit#7465)

* fixed dealing with None values in then/else branches of Select

* generalized the solution for the case when the condition is not a single-element tensor

* fix when both branches are None

* minor corrections

* rewritten Select unit-tests; fixed for condition with [True] mask

* removed mutable default arg from build_graph, added a few more test cases with masked condition, other minor corrections

* corrected output_shape calculation when broadcasting is off

* layer tests fixed: relaxed assert for condition shape to let pass TF Select

* corrected shape calculation when the condition is not elementwise equal, corrected the calculation for TF, and corrected the shape calculation when values are not set

* fixed a typo for Select from TF Where
  • Loading branch information
pavel-esir authored Oct 15, 2021
1 parent 5b075d8 commit f661133
Show file tree
Hide file tree
Showing 5 changed files with 352 additions and 189 deletions.
77 changes: 57 additions & 20 deletions model-optimizer/extensions/ops/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import numpy as np

from mo.front.common.partial_infer.utils import is_fully_defined
from mo.graph.graph import Node, Graph
from mo.front.common.partial_infer.utils import compatible_shapes, dynamic_dimension, shape_array, is_fully_defined
from mo.graph.graph import Node, Graph, Error
from mo.ops.op import Op
from mo.utils.broadcasting import bi_directional_shape_broadcasting, bi_directional_broadcasting

Expand Down Expand Up @@ -35,31 +35,68 @@ def infer(node: Node):
"Select operation must have 3 inputs: 'condition', 'then' and 'else' tensors for node {}".format(node_name)

condition_value = node.in_port(0).data.get_value()
condition_shape = node.in_port(0).data.get_shape()
resulting_tensors = [node.in_port(1).data.get_value(), node.in_port(2).data.get_value()]

a_shape = node.in_port(1).data.get_shape()
b_shape = node.in_port(2).data.get_shape()
output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
assert output_shape is not None, 'Input shapes for node {} are not broadcast-able'.format(node_name)
broadcast_rule = node.soft_get('auto_broadcast', 'numpy')

if broadcast_rule == 'numpy':
msg = "In Select node '{}' condition and then/else shapes must be broadcastable. " \
"But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
node_name, condition_shape, a_shape, b_shape)

output_shape = bi_directional_shape_broadcasting(a_shape, b_shape)
assert output_shape is not None, msg

# if Select was created from TF Where operations then 1D condition must have the same size
# as 0-index dimension of output_shape. This condition is different from being numpy compatible
# but by adding ones to the end we can achieve numpy compatibility, as in transformation SelectBroadcast.py
if node.has_valid('format') and node['format'] == 'tf' and len(condition_shape) == 1:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_ops.py#L4596-L4598
msg_tf = "In Select node '{}' if 'condition' is a 1D tensor then it's size " \
"must be matching with the first dimension of then/else branches. " \
"But instead got: cond_shape={}, then_shape={}, else_shape={}".format(
node_name, condition_shape, a_shape, b_shape)

assert condition_shape[0] == output_shape[0], msg_tf
condition_shape = np.concatenate((condition_shape, np.ones(len(output_shape) - 1)))

output_shape = bi_directional_shape_broadcasting(output_shape, condition_shape)
assert output_shape is not None, msg

elif broadcast_rule == 'pdpd':
# todo: add pdpd broadcasting rule
# note that additionally to output_shape resulting_tensors must be broadcasted as well
raise Error("PDPD broadcasting rule is not implemented yet")
else: # broadcasting is not allowed
assert compatible_shapes(a_shape, b_shape) and compatible_shapes(condition_shape, a_shape), \
'In node \'{}\' for Select operation when broadcasting is off all inputs must be of the same shape. ' \
'But instead got: cond_shape={}, then_shape={}, else_shape={}'.format(
node_name, condition_shape, a_shape, b_shape)
output_shape = shape_array([i if i is not dynamic_dimension else j for i, j in zip(a_shape, b_shape)])

node.out_port(0).data.set_shape(output_shape)

if condition_value is not None:
if resulting_tensors[0] is not None:
resulting_tensors[0] = bi_directional_broadcasting(resulting_tensors[0], b_shape)
if resulting_tensors[1] is not None:
resulting_tensors[1] = bi_directional_broadcasting(resulting_tensors[1], a_shape)
condition_value = bi_directional_broadcasting(condition_value, output_shape)

output_value = np.ma.where(condition_value, resulting_tensors[0], resulting_tensors[1])
if condition_value.size != 1:
if np.any(output_value == None):
# If any element of output value is None that means that we use the value from the 'then' or the
# 'else' tensor which is not defined, this means that we cannot perform value propagation.
output_value = None
else:
output_value = output_value.astype(resulting_tensors[not np.bool(condition_value.item(0))].dtype)

if output_value is not None:
if is_fully_defined(condition_value) and np.all(condition_value == condition_value.item(0)):
# in some graphs Select condition is always True[False] and
# one of the branches is None (which is not selected)
# if we use np.where for such cases then dtype of output_value will be object (non numeric type)
# and subsequent numpy operation on such tensors will fail
output_value = resulting_tensors[not np.bool(condition_value.item(0))]
if output_value is None:
return
if broadcast_rule == 'numpy':
output_value = bi_directional_broadcasting(output_value, output_shape)
elif broadcast_rule == 'pdpd':
# todo: add pdpd broadcasting rule
raise Error("PDPD broadcasting rule is not implemented yet")

node.out_port(0).data.set_value(output_value)
elif resulting_tensors[0] is not None and resulting_tensors[1] is not None:
output_value = np.ma.where(condition_value, resulting_tensors[0], resulting_tensors[1])
node.out_port(0).data.set_value(output_value)

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions model-optimizer/mo/front/common/partial_infer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def strict_compare_tensors(tensor1, tensor2):
:param tensor2: the second tensor to compare
:return: boolean result of the comparison
"""
if tensor1 is None and tensor2 is None:
return True
if tensor1 is None or tensor2 is None:
return False

if not isinstance(tensor1, np.ma.masked_array):
tensor1 = shape_array(tensor1)
if not isinstance(tensor2, np.ma.masked_array):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def build_graph_to_test_type_alignment(edges,
**shaped_parameter('input_1', input_shape, {'data_type': input_1_type}),
**shaped_parameter('input_2', input_shape, {'data_type': input_2_type}),
**regular_op_with_empty_data('add', {'op': 'Add', 'type': 'Add', 'type_infer': Elementwise.type_infer}),
**valued_const_with_data('const', const_value, {'data_type': const_type}),
**valued_const_with_data('const', const_value, kwargs={'data_type': const_type}),
**result('result'),
}
graph = build_graph(nodes, edges, nodes_with_edges_only=True)
Expand Down
Loading

0 comments on commit f661133

Please sign in to comment.