From fb02aec976cf86c7ab64c279ad925ad4a7395f97 Mon Sep 17 00:00:00 2001 From: Anton Chetverikov Date: Mon, 1 Nov 2021 13:52:46 +0300 Subject: [PATCH] [MO]Fix Merge shape infer function (#7864) * Update merge shape infer function * Update merge shape inference tests --- model-optimizer/extensions/ops/merge.py | 16 +++++----- .../unit_tests/extensions/ops/merge_test.py | 30 ++++++++++++++++++- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/model-optimizer/extensions/ops/merge.py b/model-optimizer/extensions/ops/merge.py index 3c21628b254c13..d94e0abe60276b 100644 --- a/model-optimizer/extensions/ops/merge.py +++ b/model-optimizer/extensions/ops/merge.py @@ -34,13 +34,15 @@ def merge_infer(node: Node): inferred_and_executable = [n for n in node.in_nodes().values() if n['is_partial_inferred'] and 'executable' in n and n['executable']] - tensor = inferred_and_executable[0] - - if all([tensor.has_valid('value') and n.has_valid('value') and strict_compare_tensors(tensor.value, n.value) - for n in inferred_and_executable]): - node.out_node().value = tensor.value.copy() - else: - node.out_node().value = None + if len(inferred_and_executable) > 0: + tensor = inferred_and_executable[0] + + if all([tensor.has_valid('value') and n.has_valid('value') and strict_compare_tensors(tensor.value, + n.value) + for n in inferred_and_executable]): + node.out_node().value = tensor.value.copy() + else: + node.out_node().value = None # do not use set_shape(tensor.shape) here because input port shape may be different from the calculated output # shape and `set_shape` will raise an error that shape has changed diff --git a/model-optimizer/unit_tests/extensions/ops/merge_test.py b/model-optimizer/unit_tests/extensions/ops/merge_test.py index ab8fba09eae050..96c1c0d518abaf 100644 --- a/model-optimizer/unit_tests/extensions/ops/merge_test.py +++ b/model-optimizer/unit_tests/extensions/ops/merge_test.py @@ -58,7 +58,7 @@ def test_merge_infer_complex_case(self): edges_with_attrs=self.edges, update_nodes_attributes=[('second', {'executable': True}), ('first', {'is_partial_inferred': False, - 'value': None}), + 'value': None}), ('merge_output', {'shape': np.array([2, 2]), 'value': None}), ('merge', {'is_not_fully_inferred': True})]) @@ -115,3 +115,31 @@ def test_merge_infer_only_second_executable(self): (flag, resp) = compare_graphs(graph, ref_graph, 'merge_output', check_op_attrs=True) self.assertTrue(flag, resp) + + def test_merge_infer_no_executable(self): + graph = build_graph_with_attrs( + nodes_with_attrs=self.nodes, + edges_with_attrs=self.edges, + update_nodes_attributes=[ + ('first', {'executable': False, 'value': np.ones([2, 2]), 'shape': int64_array([2, 2])}), + ('second', {'executable': False, 'value': np.zeros([4, 4]), 'shape': int64_array([4, 4])}) + ] + ) + + ref_graph = build_graph_with_attrs( + nodes_with_attrs=self.nodes, + edges_with_attrs=self.edges, + update_nodes_attributes=[ + ('first', {'executable': False, 'value': np.ones([2, 2]), 'shape': int64_array([2, 2])}), + ('second', {'executable': False, 'value': np.zeros([4, 4]), 'shape': int64_array([4, 4])}), + ('merge', {'is_not_fully_inferred': False}), + ('merge_output', {'shape': int64_array([2, 2]), 'value': None}) + ] + ) + + tested_class = Merge(graph=graph, attrs={}) + node = Node(graph, 'merge') + tested_class.merge_infer(node) + + (flag, resp) = compare_graphs(graph, ref_graph, 'merge_output', check_op_attrs=True) + self.assertTrue(flag, resp)