forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_max_autotune.py
167 lines (132 loc) · 6.12 KB
/
test_max_autotune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Owner(s): ["module: inductor"]
import unittest

import torch
from torch import multiprocessing as mp
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
# Allow float32 matmuls to use faster reduced-precision internal math
# (e.g. TensorFloat32) where available, per the
# torch.set_float32_matmul_precision documentation.
_FLOAT32_MATMUL_PRECISION = "high"
torch.set_float32_matmul_precision(_FLOAT32_MATMUL_PRECISION)
def benchmark_choice(choice, args, out, expected_out, timings):
    """Benchmark ``choice`` on ``args`` and record the result in ``timings``.

    Intended as a subprocess target.

    Args:
        choice: object exposing ``benchmark(*args, out=...)``; its return value
            is converted with ``torch.tensor`` and copied into ``timings``.
        args: positional inputs forwarded to ``choice.benchmark``.
        out: output tensor handed to ``choice.benchmark`` via ``out=``.
        expected_out: reference tensor compared against ``out`` with
            ``torch.testing.assert_close``, or None to skip the comparison.
        timings: tensor that receives the benchmark result in place
            (``copy_``). A tensor is used instead of a Python container so
            the parent process can observe the write.

    Raises:
        AssertionError: if ``expected_out`` is given and ``out`` does not match.
    """
    elapsed = choice.benchmark(*args, out=out)
    if expected_out is not None:
        # Validate the produced output before any timing is recorded.
        torch.testing.assert_close(out, expected_out)
    measured = torch.tensor(elapsed)
    timings.copy_(measured)
class FailChoiceCaller(ChoiceCaller):
    """A ``ChoiceCaller`` stub whose ``benchmark`` always raises.

    Used to check that a benchmarking failure in a child process is reported
    as a non-zero exit code.
    """

    def benchmark(self, *args, out):
        """Ignore all inputs and raise ``RuntimeError`` unconditionally."""
        message = "This choice caller will always throw"
        raise RuntimeError(message)
@unittest.skipIf(not HAS_CUDA, "requires CUDA")
@instantiate_parametrized_tests
class TestDoBench(TestCase):
    """Max-autotune tests: choice benchmarking in subprocesses and full compiles.

    Every test allocates tensors or buffers on a CUDA device, so the class is
    skipped when ``HAS_CUDA`` is false (e.g. when collected on a CPU-only host).
    """

    def _create_buffer(self, name, shape):
        """Return a float32 ``Buffer`` called *name* with a fixed *shape* on ``cuda:0``."""
        return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape))

    def _benchmark_mm_plus_mm_in_subproc(self, make_choice, *, check_output):
        """Benchmark a choice for ``(mat1 @ mat2) + (mat3 @ mat4)`` in a spawned child.

        Args:
            make_choice: callable ``(buffers, layout) -> choice`` that builds the
                choice to benchmark from the four input buffers (a tuple) and
                the 2x2 output layout.
            check_output: if True, the child compares its output against
                ``(mat1 @ mat2) + (mat3 @ mat4)``; if False, no comparison is made.

        Returns:
            ``(exitcode, timings, out, expected_out)``: the child's exit code, the
            tensor the child copies its benchmark result into, the output tensor
            handed to the child, and the reference output (or None).
        """
        # A dummy graph, only needed to construct the GraphLowering.
        gm = make_fx(lambda: torch.zeros(2, 3))()
        graph = GraphLowering(gm)
        # The graph handler is needed to create the benchmark example values below.
        with V.set_graph_handler(graph):
            buffers = (
                self._create_buffer("mat1", (2, 3)),
                self._create_buffer("mat2", (3, 2)),
                self._create_buffer("mat3", (2, 3)),
                self._create_buffer("mat4", (3, 2)),
            )
            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))

            mat1, mat2, mat3, mat4 = (
                AlgorithmSelectorCache.benchmark_example_value(buf) for buf in buffers
            )
            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            expected_out = (mat1 @ mat2) + (mat3 @ mat4) if check_output else None

            choice = make_choice(buffers, layout)
            # Use a tensor rather than a Python list: mutations made to a list in
            # the child process are not synced back to the parent process.
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
        return child.exitcode, timings, out, expected_out

    def test_benchmark_choice_in_subproc(self):
        # Output comparison is skipped for this aten choice (expected_out is None).
        # NOTE(review): the reason the reference check is disabled is not visible
        # in this file — confirm before enabling check_output=True.
        exitcode, timings, out, expected_out = self._benchmark_mm_plus_mm_in_subproc(
            aten_mm_plus_mm.bind, check_output=False
        )
        self.assertEqual(0, exitcode)
        print(f"timings is {timings}, out {out}, expected_out {expected_out}")

    def test_benchmark_choice_fail_in_subproc(self):
        exitcode, *_ = self._benchmark_mm_plus_mm_in_subproc(
            lambda *_: FailChoiceCaller("fail_choice_caller", [], None),
            check_output=True,
        )
        # The choice raises inside the child, so the child must exit non-zero.
        self.assertNotEqual(0, exitcode)

    @parametrize("autotune_in_subproc", (True, False))
    def test_max_autotune_mm_plus_mm(self, autotune_in_subproc):
        """
        This crashed previously due to a triton issue: https://github.com/openai/triton/issues/1298 .
        With autotuning in a subprocess, we don't crash anymore.
        """
        m, n, k = 2048, 1536, 64

        def mm_plus_mm(a, b, c, d):
            return a @ b + c @ d

        a = torch.randn(m, k).cuda()
        b = torch.randn(k, n).cuda()
        c = torch.randn(m, k).cuda()
        d = torch.randn(k, n).cuda()

        with config.patch(
            {"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
        ):
            torch.compile(mm_plus_mm)(a, b, c, d)

    def test_max_autotune_regular_mm(self):
        """
        Make sure autotuning mm in subprocesses works without crashes.
        """

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            torch.compile(mm)(a, b)

    def test_max_autotune_addmm(self):
        """
        Make sure autotuning addmm in subprocesses works without crashes.
        """

        def addmm(x, a, b):
            return torch.addmm(x, a, b)

        x = torch.randn(100).cuda()
        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()
        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            torch.compile(addmm)(x, a, b)
if __name__ == "__main__" and HAS_CUDA:
    # Every test in this file needs a CUDA device; without one, run nothing.
    run_tests()