# Owner(s): ["module: inductor"]
import contextlib
import importlib
import math
import os
import sys
import unittest
from functools import partial

import torch
from torch._dynamo.testing import make_test_cls_with_patches
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
)
from torch.testing._internal.common_utils import (
    IS_CI,
    IS_WINDOWS,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor_dynamic_shapes yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from inductor.test_torchinductor import (
    check_model,
    check_model_cuda,
    CommonTemplate,
    copy_tests,
)

importlib.import_module("filelock")

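# Tests to skip, per device, when running the common template under dynamic
# shapes; consumed by the copy_tests calls below.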
test_skips = {
    "test_cpp_wrapper_dynamic_shapes": ("cpu",),
    "test_cudnn_rnn_dynamic_shapes": ("cuda",),
    "test_kwargs_dynamic_shapes": ("cpu",),
    # test_roi_align uses torchvision, which doesn't work with dynamic shapes
    "test_roi_align_dynamic_shapes": ("cpu", "cuda"),
    #
    # These are from switching to specialize_int=False
    #
    "test_div8_dynamic_shapes": ("cpu", "cuda"),  # StopIteration
    # NotImplementedError: argument of type: <class 'sympy.core.add.Add'>
    "test_reflection_pad2d_backward_dynamic_shapes": ("cpu", "cuda"),
    "test_both_scalars_dynamic_shapes": ("cpu", "cuda"),  # StopIteration
}


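# Build a "DynamicShapes<cls>" copy of a test class with
# torch._dynamo.config.dynamic_shapes patched to True; the copied test
# methods gain a "_dynamic_shapes" suffix.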
def make_dynamic_cls(cls):
    return make_test_cls_with_patches(
        cls,
        "DynamicShapes",
        "_dynamic_shapes",
        (torch._dynamo.config, "dynamic_shapes", True),
    )


DynamicShapesCommonTemplate = make_dynamic_cls(CommonTemplate)


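# Copy the dynamic-shapes template tests onto per-device TestCase classes,
# skipping the entries listed in test_skips.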
if HAS_CPU:

    class DynamicShapesCpuTests(TestCase):
        common = check_model
        device = "cpu"

    copy_tests(DynamicShapesCommonTemplate, DynamicShapesCpuTests, "cpu", test_skips)


if HAS_CUDA and not TEST_WITH_ASAN:

    class DynamicShapesCudaTests(TestCase):
        common = check_model_cuda
        device = "cuda"

    copy_tests(DynamicShapesCommonTemplate, DynamicShapesCudaTests, "cuda", test_skips)


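# Device-generic tests that exercise torch.compile(dynamic=True) directly.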
class TestInductorDynamic(TestCase):
    compile_fn = partial(torch.compile, dynamic=True)

    def setUp(self):
        # HAS_CUDA also checks compute capability to skip tests
        # on older devices
        if self.device_type == "cuda" and not HAS_CUDA:
            self.skipTest("Triton not available")
        torch._dynamo.reset()
        super(TestCase, self).setUp()
        # this should be in setUpClass, but device-generic tests
        # don't work with setUpClass well (non-deterministically the wrong setUpClass is resolved),
        # so put it in test setUp, it's cheap
        self._stack = contextlib.ExitStack()
        self._stack.enter_context(
            torch._inductor.config.patch(
                {
                    "debug": False,
                    "cpp.min_chunk_size": 1,
                    "triton.autotune_pointwise": False,  # too slow
                    "implicit_fallbacks": False,
                }
            )
        )

    def tearDown(self):
        self._stack.close()
        super(TestCase, self).tearDown()
        torch._dynamo.reset()

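    # arange whose upper bound depends on tensor data should compile and match eager.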
    def test_arange_dynamic(self, device):
        def fn(a):
            batch_size = a.numel()
            max_len = a.max()
            return ~(
                torch.arange(0, max_len, device=a.device)
                .type_as(a)
                .repeat(batch_size, 1)
                .lt(a.unsqueeze(1))
            )

        a = torch.randint(10, 30, (10,), device=device)
        a[0] = 29  # fix max_len
        opt = self.compile_fn(fn)
        res = opt(a)
        ref = fn(a)
        self.assertEqual(res, ref)

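    # "same" padding whose amounts are computed from the input's height/width.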
    @onlyCUDA
    def test_pad_dynamic(self, device):
        def get_same_padding(x: int, k: int, s: int, d: int):
            return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)

        def pad_same(x, k, s, d=(1, 1), value=0):
            ih, iw = x.size()[-2:]
            pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(
                iw, k[1], s[1], d[1]
            )
            if pad_h > 0 or pad_w > 0:
                x = torch.nn.functional.pad(
                    x,
                    [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
                    value=value,
                )
            return x

        x = torch.randn(2, 24, 110, 110, device=device)
        opt = self.compile_fn(pad_same)
        res = opt(x, (5, 5), (2, 2))
        ref = pad_same(x, (5, 5), (2, 2))
        self.assertEqual(res, ref, atol=0, rtol=0)


instantiate_device_type_tests(TestInductorDynamic, globals())

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if (HAS_CPU or HAS_CUDA) and not TEST_WITH_ROCM:
        run_tests(needs="filelock")