# Owner(s): ["module: dynamo"]
import unittest
import weakref

import torch

import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing


class RecompileUxTests(torch._dynamo.test_case.TestCase):
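    """Tests for the user experience around Dynamo recompilation: behavior at
    the cache size limit and the content of guard-failure warnings."""
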
    # TODO(whc) dynamo actually recompiles one more time than the cache limit
    cache_limit = 1

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            torch._dynamo.config.patch("cache_size_limit", cls.cache_limit)
        )

    def test_drop_cache_on_skip(self):
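        # With cache_limit == 1, the second call (a new value of `i`) should
        # exceed the limit; Dynamo should then drop its cache entry, letting
        # the compiled forward be GC'ed and the finalizer below fire.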
        def model(x, i):
            return x + i

        attached = False
        triggered = False

        def trigger():
            nonlocal triggered
            triggered = True

        def compiler(gm, input):
            nonlocal attached
            f = gm.forward
            assert not attached
            # NB: making this a weakref.ref causes the cycle to no
            # longer be promptly GC'ed
            weakref.finalize(f, trigger)
            attached = True
            return f

        x = torch.randn(2)
        for i in range(2):
            opt_model = torch._dynamo.optimize(compiler)(model)
            opt_model(x, i)

        self.assertTrue(triggered)

    def test_loop_torture(self):
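        # `iters` is a fresh random tensor on every call, so each call
        # presumably fails the guards from the previous compile; this checks
        # that compilation stops once the cache limit is reached.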
        def loop_torture(input, iters):
            out = input
            # randint itself causes one graph break
            for _ in range(iters):
                out += input
            return out

        compile_counter = torch._dynamo.testing.CompileCounter()
        for _ in range(10):
            x = torch.randn(3)
            iters = torch.randint(low=0, high=1000, size=())
            opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture)
            opt_loop_torture(x, iters)

        # Currently, we recompile each time;
        # we'd probably like to bail out quickly and warn.
        # TODO(whc) these checks fail on py37. Why?
        # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit)
        # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit)

        # compile_counter only sees frames that were fed to the backend compiler,
        # which is a subset of counters["frames"]["ok"] -- probably because
        # counters["frames"]["ok"] includes frames not containing torch ops?
        self.assertEqual(compile_counter.frame_count, self.cache_limit)

    def test_dynamic_input(self):
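        # Randomized batch sizes give a different input shape on (almost)
        # every call; after `expected_recompiles` recompilations Dynamo should
        # warn exactly once about hitting config.cache_size_limit.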
        def model(input):
            return input + input

        expected_recompiles = 2
        compile_counter = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch("cache_size_limit", expected_recompiles):
            with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
                for _ in range(10):
                    bsz = torch.randint(low=0, high=1000, size=())
                    x = torch.randn((bsz, 3, 4))
                    opt_model = torch._dynamo.optimize(compile_counter)(model)
                    opt_model(x)

        self.assertEqual(compile_counter.frame_count, expected_recompiles)
        self.assertEqual(len(logs.records), 1)
        print(logs.records[0])
        self.assertTrue(
            logs.records[0]
            .getMessage()
            .startswith("torch._dynamo hit config.cache_size_limit")
        )

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_nvfuser_guards(self):
        # we may want to model dynamo's guards sufficiently after nvfuser's
        # ProfilingExecutor guards, such that we ensure dynamo is in charge of
        # all the recompilations at the top level, and we could thus simplify
        # the underlying torchscript executor
        def func(a, b, c):
            return a + b * c

        a = torch.rand(3, 4, 5, device="cuda")
        b = torch.rand(3, 4, 5, device="cuda")
        b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5)
        b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1)
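        # b_v views contiguous memory, so it ends up with the same sizes and
        # strides as b; b_p is a non-contiguous permutation whose strides
        # differ, which should trip the stride guard and force a recompile.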
        c = torch.rand(3, 4, 5, device="cuda")
        compile_counter = torch._dynamo.testing.CompileCounter()

        with torch._dynamo.config.patch("cache_size_limit", 2):
            opt_func = torch._dynamo.optimize(compile_counter)(func)
            opt_func(a, b, c)  # warmup
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b, c)  # no guard fail or recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_v, c)  # a view should not cause nvfuser recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_p, c)  # a permutation should cause recompile
            self.assertEqual(compile_counter.frame_count, 2)

    def assert_single_log_contains(self, logs, contains_str):
        self.assertEqual(len(logs.records), 1)
        self.assertTrue(
            logs.records[0].getMessage().find(contains_str) >= 0,
            msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
        )

    def test_verbose_tensor_check(self):
        def func(a):
            # Warning: choose a function here whose meta implementation lives
            # entirely in C++. If you do a Python one, Dynamo will dive into
            # torch._refs, which is OK, but it will muddy up the warnings
            return torch.add(a, 4)

        def cache_fail_test(cached_input, missed_input, expected_failure):
            # TODO(whc) maybe it's hacky to have a 'test within a test', but
            # this seemed convenient
            torch._dynamo.reset()
            torch._dynamo.utils.counters.clear()
            opt_func = torch._dynamo.optimize("eager")(func)
            # warmup
            opt_func(cached_input)

            with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
                opt_func = torch._dynamo.optimize("eager")(func)
                opt_func(missed_input)
            self.assert_single_log_contains(logs, expected_failure)
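
        # Each case below perturbs exactly one property of `a` (size, strides,
        # rank, dispatch key set, dtype, requires_grad) and checks that the
        # recompile warning names that property.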
        a = torch.rand(3, 4, 5)
        cache_fail_test(
            a, a[0:2, :, :], "tensor 'a' size mismatch at index 0. expected 3, actual 2"
        )
        cache_fail_test(
            a,
            a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
            "tensor 'a' strides mismatch at index 0. expected 20, actual 1",
        )
        cache_fail_test(a, a[0, :, :], "tensor 'a' rank mismatch. expected 3, actual 2")
        cache_fail_test(a, a.to("meta"), "tensor 'a' dispatch key set mismatch.")
        cache_fail_test(
            a,
            a.to(torch.float16),
            "tensor 'a' dtype mismatch. expected Float, actual Half",
        )
        a_grad = a.clone()
        a_grad.requires_grad = True
        cache_fail_test(
            a, a_grad, "tensor 'a' requires_grad mismatch. expected requires_grad=0"
        )

    def test_mismatched_type(self):
        a = torch.rand(3, 4, 5)
        b = torch.rand(3, 4, 5)

        def func(a, b):
            return a + b

        opt_func = torch._dynamo.optimize("eager")(func)
        # warmup
        opt_func(a, b)

        with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
            opt_func = torch._dynamo.optimize("eager")(func)
            opt_func(a, 1)
        self.assert_single_log_contains(
            logs, "expected type of 'b' to be a tensor type, ' but found <class 'int'>"
        )


# TODO(jansel): these pass with pytest, but not with pytorch CI
# if __name__ == "__main__":
#     from torch._dynamo.testing import run_tests
#     run_tests()