forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_memory_efficient_fusion.py
371 lines (299 loc) · 11.3 KB
/
test_memory_efficient_fusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
# Owner(s): ["module: functorch"]
import torch
import torch.nn as nn
import torch.fx as fx
from functorch import make_fx
from torch.nn import functional as F
from functorch.compile import memory_efficient_fusion
from torch._functorch.compile_utils import fx_graph_cse
from torch.testing._internal.common_utils import TestCase, run_tests
import inspect
import random
from typing import Callable
import unittest
HAS_CUDA = torch.cuda.is_available()
def _num_args(fn: Callable):
return len(inspect.signature(fn).parameters)
def gelu_bias(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
def swish(x):
return x * torch.sigmoid(x)
def mish(x):
return x.mul(torch.tanh(F.softplus(x)))
def hard_sigmoid(x):
return (x + 3.0).clamp(min=0.0, max=6.0).div(6.0)
def hard_swish(x):
return x * (x + 3.0).clamp(min=0.0, max=6.0).div(6.0)
def hard_mish(x):
return 0.5 * x * (x + 2.0).clamp(min=0.0, max=2.0)
# todo: convert these into tests
# def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
# B, C, H, W = x.shape
# x_dtype = x.dtype
# if flatten:
# x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
# std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
# else:
# x = x.reshape(B, groups, C // groups, H, W)
# std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
# return std.expand(x.shape).reshape(B, C, H, W)
# class EvoNorm2dS0(nn.Module):
# def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
# super().__init__()
# self.apply_act = apply_act # apply activation (non-linearity)
# if group_size:
# assert num_features % group_size == 0
# self.groups = num_features // group_size
# else:
# self.groups = groups
# self.eps = eps
# self.weight = nn.Parameter(torch.ones(num_features))
# self.bias = nn.Parameter(torch.zeros(num_features))
# self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
# self.reset_parameters()
# def reset_parameters(self):
# nn.init.ones_(self.weight)
# nn.init.zeros_(self.bias)
# if self.v is not None:
# nn.init.ones_(self.v)
# def forward(self, x):
# x_dtype = x.dtype
# v_shape = (1, -1, 1, 1)
# if self.v is not None:
# v = self.v.view(v_shape).to(dtype=x_dtype)
# x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
# return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
# device = "cuda"
# dtype = torch.float
# evo_norm = EvoNorm2dS0(2048)
# evo_norm_inp = [(128, 2048, 8, 8)]
def run_and_compare_activation(self, fn, inps):
with torch.jit.fuser("fuser1"):
device = "cuda"
dtype = torch.float
if isinstance(fn, nn.Module):
fn = fn.to(device=device, dtype=dtype)
ref_args = [torch.randn(shape, device=device, dtype=dtype, requires_grad=True) for shape in inps]
res_args = [i.clone().detach().requires_grad_(True) for i in ref_args]
ref = fn(*ref_args)
ref.sum().backward()
mem_optimized_fn = memory_efficient_fusion(fn)
for _ in range(5):
for i in res_args:
i.grad = None
res = mem_optimized_fn(*res_args)
res.sum().backward()
self.assertEqual(ref, res)
for ref_arg, res_arg in zip(ref_args, res_args):
self.assertEqual(ref_arg.grad, res_arg.grad)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
class TestMemoryEfficientOpAuthoring(TestCase):
def test_gelu_bias(self):
run_and_compare_activation(self, gelu_bias, [(1024,), (1024,)])
def test_mish(self):
run_and_compare_activation(self, mish, [(1024,)])
def test_swish(self):
run_and_compare_activation(self, swish, [(1024,)])
def test_hard_sigmoid(self):
run_and_compare_activation(self, hard_sigmoid, [(1024,)])
def test_hard_swish(self):
run_and_compare_activation(self, hard_swish, [(1024,)])
def test_layer_norm(self):
def layer_norm(x, weight, bias):
dim = -1
eps = 1e-5
mean = torch.mean(x, dim, keepdim=True)
centered = x - mean
var = torch.sum(centered * centered, dim, keepdim=True) / x.size(-1)
rvar = 1. / torch.sqrt(var + eps)
normed = (x - mean) * rvar
return normed * weight + bias
bs = 10
ln_size = 16
layer_norm_inps = [(bs, ln_size), (ln_size,), (ln_size,)]
run_and_compare_activation(self, layer_norm, layer_norm_inps)
def test_rmsnorm(self):
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# layer norm should always be calculated in float32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
bs = 256
seq = 256
hidden = 1024
t5_norm = T5LayerNorm(hidden)
t5_norm_inputs = [(bs, seq, hidden)]
run_and_compare_activation(self, t5_norm, t5_norm_inputs)
# TODO - Assertion failure
# def test_hard_mish(self):
# for compiler in compilers:
# run_and_compare_activation(hard_mish, 1024)
# check if the CSE modified graph of f has delta less nodes, and do not reduce the number of nodes further on a second pass.
# delta is an integer >= -1. If delta = -1, only check if the new graph
# has less or equal number of nodes
def check(f, t, delta, check_val=True, graph_input=False):
if graph_input:
fx_g = f
else:
fx_g = make_fx(f)(t)
new_graph = fx_graph_cse(fx_g.graph)
new_g = fx.GraphModule(fx_g, new_graph)
# the number of nodes decrease/ or stay the same
old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes)
if delta == -1:
assert old_num_nodes >= new_num_nodes, (
f"number of nodes increased {old_num_nodes}, {new_num_nodes}")
else:
assert old_num_nodes == new_num_nodes + delta, (
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}")
# a second pass should not reduce more nodes
pass_2_graph = fx_graph_cse(new_graph)
pass_2_num_nodes = len(pass_2_graph.nodes)
assert pass_2_num_nodes == new_num_nodes, (
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}")
# check correctness
if check_val:
true_result = fx_g(t)
our_result = new_g(t)
if true_result is None: # both return None
assert our_result is None, f"true result is None, CSE result is {our_result}"
else: # results returned are the same
assert torch.all(true_result == our_result), (
f"results are different {true_result}, {our_result}") # check results are the same
class NoChangeTestCase(TestCase):
def test_nochange(self):
def f(x):
a = x + 1
b = x + a
a = x
d = x + a
return b + d
t = torch.randn(2, 2)
check(f, t, 0)
def test_empty(self):
def f(x):
pass
t = torch.randn(2, 2)
check(f, t, 0)
def test_rand_like(self):
def f(x):
a = torch.rand_like(x)
b = torch.rand_like(x)
return a + b
t = torch.randn(2, 2)
check(f, t, 0, check_val=False)
def test_rand_n(self):
def f(x):
a = torch.randn(4)
b = torch.randn(4)
return a + b
t = torch.randn(2, 2)
check(f, t, 0, check_val=False)
class ReduceTestCase(TestCase):
def test_immutable_list_type(self):
def f(x):
a = x.sum(dim=1)
b = x.sum(dim=1)
c = x.sum()
d = x.sum()
return a + b + c + d
t = torch.randn(2, 2)
check(f, t, 2)
def test_immutable_list_multiple_entries(self):
def f(x):
a = x.sum(dim=[0, 1])
b = x.sum(dim=[0, 1])
c = x.sum(dim=1)
d = x.sum(dim=1)
return a + b + c + d
t = torch.randn(2, 2)
check(f, t, 2)
def test_simple(self):
def f(x):
a = x.cos()
b = x.cos()
c = a + a
d = b + b
return c + d
t = torch.randn(2, 2)
check(f, t, 2)
def test_simple_2(self):
def f(x):
a = x.cos().sin()
b = x.cos().sin()
c = a + a
d = b + b
return c + d
t = torch.randn(1)
check(f, t, 3)
def test_two_args_default(self):
def f(x):
a = x.sum(dim=1)
b = x.sum(dim=1, keepdim=False)
c = x.sum(dim=1, keepdim=False)
d = x.sum(dim=1)
return a + b + c + d
t = torch.randn(2, 2)
check(f, t, 3)
def test_two_args(self):
def f(x):
a = x.sum(dim=1)
b = x.sum(dim=1, keepdim=True)
c = x.sum(dim=1, keepdim=True)
d = x.sum(dim=1)
return a + b + c + d
t = torch.randn(2, 2)
check(f, t, 2)
def test_simple_multiple_same_ops(self):
def f(x):
a = x.sum()
b = x.sum()
c = x.sum()
d = x.sum()
return a + b + c + d
t = torch.randn(2, 2)
check(f, t, 3)
def test_nested_immutable_list_type(self):
def f(x):
a = torch.cat((x, x))
b = torch.cat((x, x))
return a + b
t = torch.randn(2, 2)
check(f, t, 1)
def test_kwarg(self):
def f(x):
a = torch.ones_like(x)
b = torch.ones_like(x)
return a + b
t = torch.randn(2, 2)
check(f, t, 1)
class RandomOpTestCase(TestCase):
def test_random(self):
def f(x):
vals = [x]
ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
for _ in range(100):
new_val = random.choice(ops)(random.choice(vals))
vals.append(new_val)
return vals[-1]
fx_g = fx.symbolic_trace(f)
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
t = torch.randn(2, 2)
for _ in range(30):
check(fx_g, t, -1, graph_input=True)
if __name__ == "__main__":
run_tests()