# mypy: allow-untyped-defs
import collections
from collections import defaultdict
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten

# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
# The use case and more information could be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"
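
# Illustrative picture of the intended split (the node names below are
# hypothetical, not part of this module): for a traced graph
#
#     x (placeholder) ──────────────► add ─► output
#     w (get_attr) ─► mul(w, 2) ──────┘
#
# `mul` depends only on the constant `w`, so constant_graph_tag marks `w` and
# `mul` with CONST_MODULE_TAG and `x`/`add`/`output` with MODULE_TAG, and
# run_and_get_constant_graph lifts the CONST_MODULE_TAG nodes into a separate
# GraphModule whose outputs feed the main module.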


def replace_node_with_constant(gm, node, constant, name=None):
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)
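
# Hedged before/after sketch of the rewrite above (node names are hypothetical,
# and the printout is simplified relative to torch.fx's full node format):
#
#     before:  %mul = call_function[target=aten.mul.Tensor](args = (%w, 2))
#     after:   %_frozen_param0 = get_attr[target=_frozen_param0]
#
# i.e. the folded node is swapped for a get_attr of a freshly registered buffer
# holding the precomputed tensor, and the original node is erased.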


class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        if (
            node.target == torch.ops.prims.convert_element_type.default
            and node.args[0].op == "get_attr"  # type: ignore[union-attr]
            and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
            and node.args[1] == torch.bfloat16
        ):
            # For the pattern int8_weight -> dq -> bf16_weight
            return True
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into an int8_weight and leave dq in the
            # graph to be fused
            return True
        return False

    def node_to_last_non_output_use(self):
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        if node.target == "output":
            # because we remove nodes from env on last non output use,
            # re-define them now or we'll get an error in the interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # We need to do this weird thing because in cases where flattened_inputs
        # contains a ScriptObject, equality checking results in a type error if
        # the types are different.
        if any(
            type(self.unknown_value) == type(input_) and self.unknown_value == input_
            for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding triton kernel returns the inputs -- fix this
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning them into tensors would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        env = {}
        for n in self.module.graph.find_nodes(op="placeholder"):
            env[n] = self.unknown_value
        return super().run(initial_env=env)


@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node, constant in cf.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    erased_params = []

    # Get all attr users by looking up the graph instead of node.users, because in this case
    # _tensor_constant0 and _tensor_constant0_1 are actually referring to the same tensor.
    #
    #    opcode         name                 target            args                         kwargs
    #    -------------  -------------------  ----------------  ---------------------------  --------
    #    placeholder    arg0_1               arg0              ()                           {}
    #    get_attr       _tensor_constant0    state             ()                           {}
    #    call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0)  {}
    #    get_attr       _tensor_constant0_1  state             ()                           {}
    #    call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)     {}
    #    output         output               output            ([add],)                     {}
    get_attr_node_users = defaultdict(list)
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            get_attr_node_users[node.target].extend(node.users.keys())
    for node in gm.graph.find_nodes(op="get_attr"):
        if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            erased_params.append(node)

    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
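

# A minimal sketch of restricting what gets folded via `constraint_fn`
# (illustrative only; `only_small_tensors` and the reliance on a "val" entry in
# node.meta are assumptions about the caller's graph, not part of this module):
#
#     def only_small_tensors(node: torch.fx.Node) -> bool:
#         val = node.meta.get("val")
#         return val is not None and val.numel() <= 1024
#
#     constant_fold(gm, constraint_fn=only_small_tensors)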


@torch.utils._python_dispatch._disable_current_modes()
def constant_graph_tag(gm: torch.fx.GraphModule):
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node in gm.graph.nodes:
        if (
            node.op == "get_attr"
            or node in cf.node_replacements
            or node in cf.replaced_uses
        ):
            node.meta[META_TAG] = CONST_MODULE_TAG
        else:
            node.meta[META_TAG] = MODULE_TAG


def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule which corresponds to the part of the provided gm
    that could be constant folded.
    """
    constant_graph_tag(gm)

    # Rewrite the tags: if a constant is consumed directly, without any folding
    # opportunity, keep it in the main gm.
    for node in gm.graph.find_nodes(op="get_attr"):
        used_to_fold = False
        for u in node.users:
            if u.meta[META_TAG] == CONST_MODULE_TAG:
                used_to_fold = True
                break
        if not used_to_fold:
            node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()

    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        for user in node.users:
            if user.meta[META_TAG] == MODULE_TAG:
                output_nodes.append(new_node)
                break

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    return new_gm
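

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library surface):
    # the toy module below has a subexpression that depends only on a buffer,
    # which constant_fold should collapse into a single frozen get_attr.
    class _Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("weight", torch.ones(4))

        def forward(self, x):
            scaled = self.weight * 2  # depends only on a buffer -> foldable
            return x + scaled

    toy_gm = torch.fx.symbolic_trace(_Toy())
    constant_fold(toy_gm)

    # After folding, `weight * 2` appears as a single frozen-parameter get_attr
    # and the original multiply is gone from the graph.
    print(toy_gm.graph)
    print(toy_gm(torch.zeros(4)))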