-
Notifications
You must be signed in to change notification settings - Fork 12
/
shape_infer_helper.py
87 lines (74 loc) · 3.64 KB
/
shape_infer_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#--------------------------------------------------------------------------
import os
import sys
import onnx
# symbolic_shape_infer.py is located next to this helper in one of two layouts:
# inside the ORT package it sits in ../tools, otherwise in the parent directory.
# Whichever directory applies is appended to sys.path so the import below resolves.
file_path = os.path.dirname(__file__)
sys.path.append(
    os.path.join(file_path, '../tools')
    if os.path.exists(os.path.join(file_path, "../tools/symbolic_shape_infer.py"))
    else os.path.join(file_path, '..')
)
from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy
class SymbolicShapeInferenceHelper(SymbolicShapeInference):
    """Symbolic shape inference that can pin named dynamic axes to concrete values.

    Runs ``SymbolicShapeInference`` on a private copy of an ONNX model. When
    ``infer`` is given a mapping such as ``{"batch_size": 4, "seq_len": 7}``, any
    dimension whose symbolic name appears in the mapping is replaced by the mapped
    value during shape propagation, and again when shapes are read back through
    ``get_edge_shape``.
    """

    def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False):
        """Copy ``model`` and set up the underlying inference engine.

        Args:
            model: the ONNX ``ModelProto`` to analyse. It is copied here, so
                inference never modifies the caller's model.
            verbose: forwarded to ``SymbolicShapeInference``.
            int_max: forwarded to ``SymbolicShapeInference``.
            auto_merge: forwarded to ``SymbolicShapeInference``.
            guess_output_rank: forwarded to ``SymbolicShapeInference``.
        """
        super().__init__(int_max, auto_merge, guess_output_rank, verbose)
        self.model_ = onnx.ModelProto()
        self.model_.CopyFrom(model)
        self.all_shapes_inferred_ = False
        self.inferred_ = False
        # Defined up front so the lookups in _get_sympy_shape/get_edge_shape never
        # fail with AttributeError before infer() stores a real mapping.
        self.dynamic_axis_mapping_ = {}

    # TODO: the goal is to remove dynamic_axis_mapping.
    def infer(self, dynamic_axis_mapping=None):
        """Run shape inference once and cache the outcome.

        Args:
            dynamic_axis_mapping: dict from dynamic-axis name to the concrete value
                to substitute, e.g. ``{"batch_size": 4, "seq_len": 7}``. ``None``
                is treated as an empty mapping.

        Returns:
            The result of the last ``_infer_impl()`` pass, which is also stored in
            ``all_shapes_inferred_``.

        NOTE: the result is cached after the first call. Later calls return the
        cached value even if they pass a different mapping, and the first call's
        mapping stays in effect for ``get_edge_shape``.
        """
        if self.inferred_:
            return self.all_shapes_inferred_
        self.dynamic_axis_mapping_ = {} if dynamic_axis_mapping is None else dynamic_axis_mapping
        self._preprocess(self.model_)
        # run_ is managed by the base class; presumably it stays True while another
        # inference pass is required — TODO confirm against SymbolicShapeInference.
        while self.run_:
            self.all_shapes_inferred_ = self._infer_impl()
        self.inferred_ = True
        return self.all_shapes_inferred_

    # override _preprocess() to avoid unnecessary model copy since ctor copies the model
    def _preprocess(self, in_mp):
        """Index graph inputs and initializers of ``in_mp`` in place (no copy).

        Sets ``out_mp_`` to ``in_mp`` itself. Builds name-keyed dicts of graph
        inputs (``graph_inputs_``) and initializers (``initializers_``). Seeds
        ``known_vi_`` with the graph inputs plus value info synthesised from each
        initializer's data type and dims.
        """
        self.out_mp_ = in_mp
        graph = self.out_mp_.graph
        self.graph_inputs_ = {i.name: i for i in graph.input}
        self.initializers_ = {i.name: i for i in graph.initializer}
        self.known_vi_ = {i.name: i for i in graph.input}
        self.known_vi_.update({
            i.name: onnx.helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))
            for i in graph.initializer
        })

    # Override _get_sympy_shape() in symbolic_shape_infer.py to ensure shape inference by giving the actual value of dynamic axis
    def _get_sympy_shape(self, node, idx):
        """Return the shape of input ``idx`` of ``node`` as a list of sympy-usable dims.

        A named (str) dim resolves, in order of precedence, to:
            1. its value in ``dynamic_axis_mapping_``;
            2. an entry already in ``symbolic_dims_``;
            3. a fresh integer ``sympy.Symbol``.
        Non-str dims pass through unchanged and must not be None.
        """
        sympy_shape = []
        for d in self._get_shape(node, idx):
            if isinstance(d, str):
                if d in self.dynamic_axis_mapping_:
                    sympy_shape.append(self.dynamic_axis_mapping_[d])
                elif d in self.symbolic_dims_:
                    sympy_shape.append(self.symbolic_dims_[d])
                else:
                    sympy_shape.append(sympy.Symbol(d, integer=True))
            else:
                assert d is not None
                sympy_shape.append(d)
        return sympy_shape

    def get_edge_shape(self, edge):
        """Return the inferred shape of tensor ``edge``, or None if unavailable.

        Named dims that appear in ``dynamic_axis_mapping_`` are replaced by their
        mapped values. Other dims are returned as produced by
        ``get_shape_from_type_proto``.

        Returns None, with a printed message, if ``edge`` has no known value info.
        Also returns None if its type carries no shape.
        """
        assert self.all_shapes_inferred_, "infer() must complete successfully before querying shapes"
        if edge not in self.known_vi_:
            print("Cannot retrive the shape of " + str(edge))
            return None
        shape = get_shape_from_type_proto(self.known_vi_[edge].type)
        if shape is None:
            # get_shape_from_type_proto (defined elsewhere) may return None when the
            # value info has no shape; iterating it would raise TypeError.
            return None
        mapping = self.dynamic_axis_mapping_
        return [mapping[d] if isinstance(d, str) and d in mapping else d for d in shape]

    def compare_shape(self, edge, edge_other):
        """Return True if ``edge`` and ``edge_other`` have equal inferred shapes.

        Raises:
            Exception: if the shape of either edge cannot be retrieved.
        """
        assert self.all_shapes_inferred_, "infer() must complete successfully before comparing shapes"
        shape = self.get_edge_shape(edge)
        shape_other = self.get_edge_shape(edge_other)
        if shape is None or shape_other is None:
            raise Exception("At least one shape is missed for edges to compare")
        return shape == shape_other