forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunify_refinements.py
120 lines (102 loc) · 3.05 KB
/
unify_refinements.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined]
def infer_symbolic_types_single_pass(traced):
    """
    Run the symbolic type inferencer exactly once on ``traced``.

    Collects refinement constraints with ``Refine``, solves them with
    ``unify_eq`` and writes the solution back into the node types of
    ``traced.graph`` (mutated in place; nothing is returned).
    """
    refiner = Refine(traced)
    refiner.refine()
    solution = unify_eq(refiner.constraints)
    substitute_all_types(traced.graph, solution)
def infer_symbolic_types(traced):
    """
    Run the symbolic type inferencer twice on ``traced``.

    A single pass is not always enough to infer all the information, as in
    the case of broadcasting, so the refine -> unify -> substitute cycle is
    repeated on the already-updated graph. Afterwards the second pass's
    refiner has ``symbolic_relations()`` called on it; its result is not
    used here. ``traced.graph`` is mutated in place; nothing is returned.
    """
    for _ in range(2):
        # Each pass builds a fresh Refine over the graph as updated by the
        # previous pass.
        refiner = Refine(traced)
        refiner.refine()
        solution = unify_eq(refiner.constraints)
        substitute_all_types(traced.graph, solution)
    refiner.symbolic_relations()
def convert_eq(list_of_eq):
    """
    Split equality constraints into the form the unification library expects.

    Each constraint contributes its ``lhs`` to the first tuple and its
    ``rhs`` to the second, at matching positions.

    Returns:
        ``(lhs_tuple, rhs_tuple)``; both are empty for no constraints.
    """
    # Single pass over the input so one-shot iterators are also accepted.
    sides = [(eq.lhs, eq.rhs) for eq in list_of_eq]
    left_sides = tuple(left for left, _ in sides)
    right_sides = tuple(right for _, right in sides)
    return left_sides, right_sides
def unify_eq(list_of_eq):
    """
    Solve a collection of equality constraints by unification.

    The constraints are flattened by ``convert_eq`` into a pair of tuples
    and passed to ``unify``; its result (the most general unifier) is
    returned unchanged.

    NOTE(review): the unification library presumably returns ``False``
    rather than a mapping when the constraints cannot be unified; callers
    here treat the result as a dict — confirm.
    """
    return unify(*convert_eq(list_of_eq))
def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier ``mapping`` to a single type ``t``.

    - A ``Var`` is replaced by its binding, or returned as-is if unbound.
    - A ``TensorType`` is rebuilt with each of its ``__args__`` entries
      replaced by its binding (one level only; entries are not recursed).
    - A list / tuple is rebuilt element-wise, recursing into each element;
      the result is always a plain ``list`` / ``tuple``.
    - Any other value is returned unchanged.
    """
    if isinstance(t, Var):
        return mapping[t] if t in mapping else t
    if isinstance(t, TensorType):
        dims = tuple(mapping[d] if d in mapping else d for d in t.__args__)
        return TensorType(dims)
    if isinstance(t, list):
        return [substitute_solution_one_type(mapping, item) for item in t]
    if isinstance(t, tuple):
        return tuple(substitute_solution_one_type(mapping, item) for item in t)
    return t
def substitute_all_types(graph, mapping):
    """
    Apply the most general unifier to the type of every node in ``graph``.

    ``mapping`` is first normalized in place. Whenever a binding's value is
    itself a key of ``mapping``, the value is replaced by that key's
    binding. This repeats until a full sweep changes no binding (a fixed
    point). Each node's ``type`` is then rewritten with
    ``substitute_solution_one_type``.

    Mutates both ``mapping`` and the nodes' ``type`` attributes; returns
    None.
    """
    converged = False
    while not converged:
        converged = True
        for var in mapping:
            previous = mapping[var]
            if previous in mapping:
                # Follow one step of a chain var -> previous -> ...
                mapping[var] = mapping[previous]
            if previous != mapping[var]:
                converged = False
    for node in graph.nodes:
        node.type = substitute_solution_one_type(mapping, node.type)
def check_for_type_equality(g1, g2):
    """
    A check equality to be used in fixed points.

    Graph structure is not compared. Only the ``type`` of corresponding
    nodes (paired in iteration order) is compared.

    Returns:
        True if both graphs have the same number of nodes and every pair of
        corresponding nodes has equal types, otherwise False.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    # zip() alone would silently stop at the shorter graph and report a
    # match whenever one graph's node types are a prefix of the other's.
    if len(nodes1) != len(nodes2):
        return False
    return not any(n.type != m.type for n, m in zip(nodes1, nodes2))