forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_torchinductor.py
7827 lines (6646 loc) · 247 KB
/
test_torchinductor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: inductor"]
import contextlib
import dataclasses
import functools
import importlib
import itertools
import os
import random
import sys
import typing
import unittest
import weakref
from typing import Callable
from unittest.mock import patch
import numpy as np
import sympy
import torch
import torch._dynamo
import torch.nn as nn
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided, same
from torch._inductor.codegen.cpp import CppVecKernelChecker
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
from torch._inductor.utils import run_and_get_triton_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import all_types
from torch.testing._internal.common_utils import (
IS_CI,
IS_MACOS,
IS_WINDOWS,
IS_X86,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TEST_WITH_SLOW,
TestCase as TorchTestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_unflatten
if IS_WINDOWS and IS_CI:
    # Windows CI machines lack this suite's dependencies: exit quietly when the
    # file is executed directly, otherwise ask unittest to skip the module.
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ != "__main__":
        raise unittest.SkipTest("requires sympy/functorch/filelock")
    sys.exit(0)

# Import the optional dependencies eagerly so a missing package surfaces as an
# ImportError when the module is loaded rather than inside an individual test.
for _required in ("functorch", "filelock"):
    importlib.import_module(_required)
del _required
from functorch.compile import config as functorch_config
from torch._decomp import get_decompositions
from torch._inductor import codecache, config, metrics, test_operators
from torch._inductor.codegen.cpp import cexpr, CppOverrides, CppVecOverrides
from torch._inductor.codegen.triton import texpr
from torch._inductor.codegen.wrapper import pexpr
from torch._inductor.compile_fx import (
compile_fx,
compile_fx_inner,
complex_memory_overlap,
)
from torch._inductor.ir import ModularIndexing
from torch._inductor.sizevars import SizeVarAllocator
from torch._inductor.utils import has_torchvision_roi_align, timed
from torch.fx.experimental.symbolic_shapes import FloorDiv
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
aten = torch.ops.aten

# Capability flags, evaluated once when the module is imported.
HAS_MULTIGPU = HAS_CUDA and torch.cuda.device_count() > 1
# NOTE(review): fbgemm availability is used as the AVX2 signal here; confirm
# this proxy is intended.
HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines

# Skip-decorator factories; apply them with a call, e.g. ``@skip_if_x86_mac()``.
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
requires_multigpu = functools.partial(
    unittest.skipIf,
    not HAS_MULTIGPU,
    "requires multiple cuda devices",
)
slow = functools.partial(unittest.skipIf, not TEST_WITH_SLOW, "too slow")
skip_if_x86_mac = functools.partial(
    unittest.skipIf,
    IS_MACOS and IS_X86,
    "Does not work on x86 Mac",
)

# Presumably the dtypes exercised by the CPU vectorization tests (name-based;
# usage is outside this view).
vec_dtypes = [torch.float32, torch.bfloat16]
# The OneDNN bf16 path requires an Intel CPU with at least the AVX-512
# avx512bw, avx512vl, and avx512dq extensions, so tests that need it are
# skipped when the processor does not meet this requirement.
@functools.lru_cache(maxsize=None)
def has_bf16_support():
    """Return True if the CPU advertises avx512bw, avx512vl and avx512dq.

    The flags are looked up in ``/proc/cpuinfo``, so only Linux can report
    support; every other platform returns False. A Linux host where the file
    cannot be read (for example a sandbox without ``/proc``) also returns False
    instead of raising, since this check may run while tests are being
    collected. The result is cached for the life of the process.
    """
    if sys.platform != "linux":
        return False
    try:
        # errors="ignore": a stray non-ASCII byte must not turn a capability
        # probe into a crash.
        with open("/proc/cpuinfo", encoding="ascii", errors="ignore") as f:
            cpuinfo = f.read()
    except OSError:
        return False
    return all(flag in cpuinfo for flag in ("avx512bw", "avx512vl", "avx512dq"))
# Unary activations used as test subjects. The same op appears in several call
# forms on purpose: nn.Module instances, torch.nn.functional calls, torch.*
# functions and Tensor methods, each wrapped in its own lambda.
unary_list = [
torch.nn.ReLU(),
torch.nn.Sigmoid(),
torch.nn.Tanh(),
torch.nn.Hardswish(),
torch.nn.LeakyReLU(0.1, inplace=False),
torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False),
torch.nn.GELU(approximate="none"),
torch.nn.GELU(approximate="tanh"),
torch.nn.ReLU6(),
torch.nn.SiLU(),
torch.nn.Hardsigmoid(),
# functional forms
lambda x: F.relu(x),
lambda x: F.sigmoid(x),
lambda x: F.tanh(x),
lambda x: F.hardswish(x),
lambda x: F.leaky_relu(x, 0.1),
lambda x: F.hardtanh(x, min_val=-0.5, max_val=4),
lambda x: F.gelu(x, approximate="none"),
lambda x: F.gelu(x, approximate="tanh"),
lambda x: F.relu6(x),
lambda x: F.silu(x),
lambda x: F.hardsigmoid(x),
# torch.* function forms
lambda x: torch.relu(x),
lambda x: torch.sigmoid(x),
lambda x: torch.tanh(x),
# Tensor method forms
lambda x: x.relu(),
lambda x: x.sigmoid(),
lambda x: x.tanh(),
]
# Binary add/sub variants used as test subjects, covering call_function and
# call_method forms, swapped operand order, and in-place (`add_`, `sub_`)
# variants that mutate their first argument.
binary_list = [
lambda x, y: torch.add(x, y), # call_function
lambda x, y: torch.add(y, x), # call_function
lambda x, y: x.add(y), # call_method
lambda x, y: x.add_(y), # call_method
lambda x, y: torch.sub(x, y), # call_function
lambda x, y: x.sub(y), # call_method
lambda x, y: x.sub_(y), # call_method
]
def requires_decomp(fn):
    """Build a decorator that skips a test unless ``fn`` has a decomposition.

    The decomposition lookup runs each time the decorated test is called, not
    at decoration time. When ``get_decompositions([fn])`` comes back empty the
    test raises ``unittest.SkipTest``.
    """

    def wrap_test(test):
        @functools.wraps(test)
        def maybe_test(*args, **kwargs):
            if get_decompositions([fn]):
                return test(*args, **kwargs)
            raise unittest.SkipTest(f"requires decomp for {fn.__name__}")

        return maybe_test

    return wrap_test
class TestCase(TorchTestCase):
    """Base class for the inductor tests in this module.

    For the lifetime of each test class, a fixed set of inductor config
    overrides is applied. Dynamo state is reset before and after every test.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        overrides = {
            "debug": True,
            "cpp.min_chunk_size": 1,
            "triton.autotune_pointwise": False,  # too slow
            "implicit_fallbacks": False,
        }
        # The ExitStack keeps the config patch active until tearDownClass.
        cls._stack = contextlib.ExitStack()
        cls._stack.enter_context(config.patch(overrides))

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()
        super().tearDownClass()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()
class ToTuple(torch.nn.Module):
    """Module whose forward wraps its single input in a one-element tuple."""

    def forward(self, x):
        packed = (x,)
        return packed
@dataclasses.dataclass
class InputGen:
    """Factory for random inputs of size ``n`` with assorted layouts.

    SweepInputs2 looks methods up by name (via getattr), so the method names
    are part of this class's interface.
    """

    n: int
    device: str

    def dense(self):
        # contiguous (n, n) float tensor
        return torch.randn(self.n, self.n, device=self.device)

    def transposed(self):
        # dense() viewed with its two dimensions swapped
        return self.dense().t()

    def strided(self):
        # (n, n) view into a larger buffer, taking every other column
        n = self.n
        backing = torch.randn(2 * n, 3 * n, device=self.device)
        return backing[n:, n::2]

    def broadcast1(self):
        return torch.randn(self.n, device=self.device)

    def broadcast2(self):
        return torch.randn(1, self.n, 1, device=self.device)

    def broadcast3(self):
        return torch.randn(1, device=self.device)

    def double(self):
        return torch.randn(self.n, self.n, device=self.device, dtype=torch.float64)

    def int(self):
        return torch.arange(self.n, device=self.device, dtype=torch.int32)
def compute_grads(args, kwrags, results, grads):
    """Differentiate the tensor outputs in ``results`` w.r.t. the input tensors.

    Args:
        args: positional inputs (any pytree). Every tensor leaf with
            ``requires_grad`` becomes a differentiation target.
        kwrags: keyword inputs (any pytree), scanned the same way as ``args``.
            The misspelled name is kept for backward compatibility.
        results: model outputs (any pytree). Only tensor leaves that require
            grad are differentiated; other leaves (Python scalars, ...) are
            ignored instead of raising AttributeError.
        grads: output gradients, one per differentiable result, in the same
            flattened order.

    Returns:
        A tuple with one gradient per input tensor that requires grad. An entry
        is ``None`` when that input does not influence the outputs
        (``allow_unused=True``).
    """

    def gather_leaf_tensors(args, kwargs):
        flat_args, _ = tree_flatten(args)
        flat_kwargs, _ = tree_flatten(kwargs)
        return [
            arg
            for arg in flat_args + flat_kwargs
            if isinstance(arg, torch.Tensor) and arg.requires_grad
        ]

    flat_results, _ = tree_flatten(results)
    # Non-tensor leaves have no ``requires_grad`` attribute; skip them.
    flat_diff_results = [
        r for r in flat_results if isinstance(r, torch.Tensor) and r.requires_grad
    ]
    assert len(flat_diff_results) > 0
    leaf_tensors = gather_leaf_tensors(args, kwrags)
    assert len(leaf_tensors) > 0
    return torch.autograd.grad(
        flat_diff_results,
        leaf_tensors,
        grads,
        allow_unused=True,
        # the caller may differentiate the same graph more than once
        retain_graph=True,
    )
def clone_preserve_strides(x):
    """Copy a tensor while keeping its size, strides and storage offset.

    The tensor's entire underlying storage is cloned, and the result is a view
    into that copy with the original geometry. Non-tensors are returned as-is.
    """
    if isinstance(x, torch.Tensor):
        numel_in_storage = x.untyped_storage().size() // x.element_size()
        flat_copy = torch.as_strided(x, (numel_in_storage,), (1,), 0).clone()
        return torch.as_strided(flat_copy, x.size(), x.stride(), x.storage_offset())
    return x
@patch.object(config, "debug", True)
def run_and_get_cpp_code(fn, args):
    """Call ``fn(*args)`` with inductor debug enabled and return its stdout.

    Dynamo state is reset first. Everything ``fn`` prints to stdout while it
    runs is captured and returned as one string.
    """
    import io
    from contextlib import redirect_stdout

    torch._dynamo.reset()
    captured = io.StringIO()
    with redirect_stdout(captured):
        fn(*args)
    return captured.getvalue()
def check_model(
    self: TestCase,
    model,
    example_inputs,
    kwargs=None,
    *,
    atol=None,
    rtol=None,
    check_lowp=True,
    exact_dtype=True,
    nopython=True,
    copy_to_cuda=True,
    reference_in_float=True,
    assert_equal=True,
    check_gradient=False,
):
    """Compile ``model`` with inductor and compare it against eager execution.

    The eager reference runs on stride-preserving clones of ``example_inputs``
    and ``kwargs``. When ``reference_in_float`` is set, float16/bfloat16
    tensors are upcast to float32 for the reference. The compiled model runs on
    ``example_inputs``/``kwargs`` themselves.

    With ``assert_equal``, outputs are compared with ``self.assertEqual``.
    Afterwards the (possibly mutated) reference inputs are compared against the
    compiled run's inputs, to validate input mutations. Without
    ``assert_equal``, only output device/size/stride/layout (and dtype when
    ``exact_dtype``) are checked. With ``check_gradient``, gradients from both
    runs are compared too.

    ``check_lowp`` and ``copy_to_cuda`` are not used here. They are accepted so
    that callers can pass the same keyword arguments as for check_model_cuda.
    """
    kwargs = kwargs or {}
    torch._dynamo.reset()

    # The eager reference always gets private copies, so an in-place model
    # cannot mutate the tensors that the compiled run will see.
    ref_inputs = [clone_preserve_strides(x) for x in example_inputs]
    ref_kwargs = {k: clone_preserve_strides(v) for k, v in kwargs.items()}
    has_lowp_args = False
    original_lowp_dtype = torch.half

    if reference_in_float:
        # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg
        def upcast_fn(x):
            nonlocal has_lowp_args
            if isinstance(x, torch.Tensor) and (
                x.dtype == torch.float16 or x.dtype == torch.bfloat16
            ):
                has_lowp_args = True
                return x.float()
            else:
                return x

        def get_original_lowp_dtype(example_inputs):
            dtypes = [x.dtype for x in example_inputs if isinstance(x, torch.Tensor)]
            dtype_set = set(dtypes)
            return dtype_set.pop() if len(dtype_set) == 1 else torch.half

        # Upcast the clones, not example_inputs. upcast_fn returns non-lowp
        # tensors unchanged, so mapping over example_inputs would hand the same
        # fp32 tensors to both runs, and the input-mutation check below would
        # compare a tensor with itself.
        ref_inputs = list(map(upcast_fn, ref_inputs))
        ref_kwargs = {k: upcast_fn(v) for k, v in ref_kwargs.items()}
        if has_lowp_args:
            original_lowp_dtype = get_original_lowp_dtype(example_inputs)
            if hasattr(model, "to"):
                model = model.to(torch.float)

    torch.manual_seed(0)
    correct = model(*ref_inputs, **ref_kwargs)
    # downcast the model back if needed
    if reference_in_float and has_lowp_args:
        if hasattr(model, "to"):
            model = model.to(original_lowp_dtype)

    torch._inductor.metrics.reset()

    called = False

    def compile_fx_wrapper(model_, example_inputs_):
        nonlocal called
        called = True
        return compile_fx(model_, example_inputs_)

    def run(*ex, **kwargs):
        return model(*ex, **kwargs)

    run = torch._dynamo.optimize(compile_fx_wrapper, nopython=nopython)(run)

    torch.manual_seed(0)
    actual = run(*example_inputs, **kwargs)
    # If this fires, torch._dynamo.explain(run, *example_inputs) can help show
    # why no graph reached compile_fx.
    assert called, "Ran graph without calling compile_fx"
    assert type(actual) == type(correct)

    correct_flat, correct_spec = tree_flatten(correct)
    actual_flat, _ = tree_flatten(actual)
    if reference_in_float:
        # bring float reference outputs back to the dtype the compiled run used
        correct_flat = tuple(
            y.to(x.dtype)
            if isinstance(y, torch.Tensor) and y.dtype.is_floating_point
            else y
            for x, y in zip(actual_flat, correct_flat)
        )
        correct = tree_unflatten(correct_flat, correct_spec)

    if assert_equal:
        self.assertEqual(
            actual,
            correct,
            atol=atol,
            rtol=rtol,
            equal_nan=True,
            exact_dtype=exact_dtype,
        )
        # In case of input mutations, check that inputs are the same
        self.assertEqual(
            ref_inputs,
            example_inputs,
            atol=atol,
            rtol=rtol,
            equal_nan=True,
            # our testing sometimes uses higher precision inputs for the reference
            exact_dtype=False,
        )
    else:
        for correct_val, actual_val in zip(correct_flat, actual_flat):
            if isinstance(correct_val, torch.Tensor):
                assert correct_val.device == actual_val.device
                assert correct_val.size() == actual_val.size()
                assert correct_val.stride() == actual_val.stride()
                assert correct_val.layout == actual_val.layout
                if exact_dtype:
                    assert correct_val.dtype == actual_val.dtype

    if check_gradient:
        # generate random unit norm gradients; non-tensor outputs have none
        grads = [
            torch.rand(r.shape, device=r.device, dtype=r.dtype)
            for r in correct_flat
            if isinstance(r, torch.Tensor) and r.requires_grad
        ]
        for g in grads:
            g /= g.norm()

        correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads)
        actual_grad = compute_grads(example_inputs, kwargs, actual, grads)

        self.assertEqual(
            actual_grad,
            correct_grad,
            atol=atol,
            rtol=rtol,
            equal_nan=True,
            exact_dtype=exact_dtype,
        )

    torch._dynamo.reset()
@torch._inductor.config.patch("triton.cudagraphs", False)
def check_model_cuda(
    self: TestCase,
    model,
    example_inputs,
    kwargs=None,
    *,
    atol=None,
    rtol=None,
    check_lowp=True,
    exact_dtype=True,
    nopython=True,
    copy_to_cuda=True,
    reference_in_float=True,
    assert_equal=True,
    check_gradient=False,
):
    """Run check_model on CUDA, then optionally repeat it in float16.

    The model is moved to CUDA. When ``copy_to_cuda`` is set, tensor inputs are
    copied there with their strides preserved. When ``check_lowp`` is set, a
    second pass converts float32 inputs and the model to float16 and raises a
    given ``rtol`` to at least 2e-3. Triton CUDA graphs are disabled for the
    duration of the call.
    """
    kwargs = kwargs or {}
    if hasattr(model, "to"):
        model = model.to("cuda")

    def strided_cuda_copy(t, dtype):
        # empty_strided + copy_ keeps the source strides on the device
        return torch.empty_strided(
            t.size(), t.stride(), device="cuda", dtype=dtype
        ).copy_(t)

    if copy_to_cuda:
        example_inputs = tuple(
            strided_cuda_copy(x, x.dtype) if isinstance(x, torch.Tensor) else x
            for x in example_inputs
        )

    forwarded = dict(
        exact_dtype=exact_dtype,
        nopython=nopython,
        reference_in_float=reference_in_float,
        assert_equal=assert_equal,
        check_gradient=check_gradient,
    )
    check_model(
        self, model, example_inputs, kwargs, atol=atol, rtol=rtol, **forwarded
    )

    if not check_lowp:
        return

    example_inputs = [
        strided_cuda_copy(x, torch.half)
        if isinstance(x, torch.Tensor) and x.dtype == torch.float
        else x
        for x in example_inputs
    ]
    if hasattr(model, "to"):
        model = model.to(torch.half)
    if rtol is not None:
        rtol = max(2e-3, rtol)
    check_model(
        self, model, example_inputs, kwargs, atol=atol, rtol=rtol, **forwarded
    )
class SweepInputs2:
    """Mixin that generates one test per pair of input layouts.

    Subclasses set ``gen`` to an object exposing ``device`` and one
    zero-argument method per name listed below, then call ``populate()``. For
    every (name1, name2) pair, this attaches a method named
    ``test_<device>_<name1>_<name2>``, which runs ``kernel`` on the two
    generated inputs through check_model.
    """

    input_gen_types1 = [
        "dense",
        "transposed",
        "strided",
        "broadcast1",
        "broadcast2",
        "broadcast3",
        "double",
        "int",
    ]
    input_gen_types2 = input_gen_types1
    gen = None

    @staticmethod
    def kernel(a, b):
        return (a + b,)

    @classmethod
    def gen_template(cls, name1, name2):
        def test(self):
            lhs = getattr(cls.gen, name1)()
            rhs = getattr(cls.gen, name2)()
            check_model(self, cls.kernel, (lhs, rhs))

        test.__name__ = f"test_{cls.gen.device}_{name1}_{name2}"
        setattr(cls, test.__name__, test)

    @classmethod
    def populate(cls):
        pairs = itertools.product(cls.input_gen_types1, cls.input_gen_types2)
        for name1, name2 in pairs:
            cls.gen_template(name1, name2)
class TestIndexingSimplification(TorchTestCase):
    """Unit tests for SizeVarAllocator.simplify_with_ranges and the
    ModularIndexing / FloorDiv expression helpers."""

    def test_indexing_simplification(self):
        sizevars = SizeVarAllocator()
        i0 = sympy.Symbol("i0", integer=True)
        i1 = sympy.Symbol("i1", integer=True)
        i2 = sympy.Symbol("i2", integer=True)
        r3 = sympy.Symbol("r3", integer=True)

        var_ranges = {i0: 3136, i1: 64, i2: 32, r3: 3}
        expr = (
            128 * i2
            + ModularIndexing(i1, 1, 64)
            + 64 * ModularIndexing(i1 + 64 * r3, 64, 2)
        )
        # check that `i1//64` is removed when i1 is always less than 64,
        # and the next simplification doesn't happen
        self.assertEqual(
            sizevars.simplify_with_ranges(expr, var_ranges),
            i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2),
        )
        # all the modular indexing should be removed when the body can't be larger than the modulus
        var_ranges[r3] = 2
        self.assertEqual(
            sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3
        )
        # if there are negative terms in ModularIndexing base, we cannot replace it with FloorDiv
        expr = ModularIndexing(i1 - 15, 1, 64)
        self.assertEqual(
            sizevars.simplify_with_ranges(expr, var_ranges),
            ModularIndexing(i1 - 15, 1, 64),
        )
        # small terms should be kept if the rest is not guaranteed to be divisible
        self.assertEqual(
            sizevars.simplify_with_ranges(FloorDiv(r3 + i2 + i1, 32), var_ranges),
            FloorDiv(r3 + i2 + i1, 32),
        )

        expr = ModularIndexing(2 * i2 + r3, 1, 64)
        # modular indexing is removed if base is smaller than modulo
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), 2 * i2 + r3)

        # check the same thing but with symbolic divisor
        self.assertEqual(FloorDiv(r3 * i0, r3), i0)
        self.assertEqual(ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10))

        # (10*i) % 10 is always zero and should get optimized away
        self.assertEqual(
            ModularIndexing(i0 + i1 * 10, 1, 10), ModularIndexing(i0, 1, 10)
        )

        # ((20*i)//2) % 10 is always zero and should get optimized away
        self.assertEqual(
            ModularIndexing(i0 + i1 * 20, 2, 10), ModularIndexing(i0, 2, 10)
        )

        # the same things happens with symbolic divisor
        self.assertEqual(
            ModularIndexing(i0 + i1 * i2 * r3, i2, r3), ModularIndexing(i0, i2, r3)
        )

        # if there are negative terms, we cannot optimize away zero terms due to https://github.com/openai/triton/issues/619
        self.assertEqual(
            ModularIndexing(-i0 + i1 * 20, 2, 10), ModularIndexing(-i0 + i1 * 20, 2, 10)
        )
        self.assertEqual(
            ModularIndexing(-15 + i1 * 20, 2, 10), ModularIndexing(-15 + i1 * 20, 2, 10)
        )

        # Constant fold from divisor into base
        self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10))
        self.assertEqual(FloorDiv(i0 * 4, 2), i0 * 2)

        # Nested modular indexing is correctly simplified. The ranges must be
        # keyed by the Symbol objects themselves: string keys such as "i1" do
        # not match the integer=True Symbols used in the expressions, so the
        # ranges would silently not apply.
        var_ranges = {i1: 13, i2: 121}
        expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
        expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
        var_ranges = {i2: 784}
        expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4)
        expected = FloorDiv(ModularIndexing(i2, 1, 28), 7)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected)
        expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4)
        self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)

    def test_indexing_join(self):
        sizevars = SizeVarAllocator()
        i0 = sympy.Symbol("i0", integer=True)
        i1 = sympy.Symbol("i1", integer=True)
        i2 = sympy.Symbol("i2", integer=True)

        # join two ModularIndexing calls into one larger one when possible
        expr1 = ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
        self.assertEqual(
            sizevars.simplify_with_ranges(expr1, {}), ModularIndexing(i0, 1, 128)
        )

        # it should also work with a scale
        self.assertEqual(
            sizevars.simplify_with_ranges(2 * expr1, {}),
            2 * ModularIndexing(i0, 1, 128),
        )

        # it should work when divisor is not 1
        expr2 = ModularIndexing(i0, 3, 32) + 32 * ModularIndexing(i0, 32 * 3, 4)
        simplified = sizevars.simplify_with_ranges(expr2, {})
        self.assertEqual(simplified, ModularIndexing(i0, 3, 128))
        self.assertEqual(expr2.subs({i0: 39485}), simplified.subs({i0: 39485}))

        # it should not happen in this case as the modulus is wrong
        expr3 = ModularIndexing(i0, 1, 30) + 32 * ModularIndexing(i0, 32, 4)
        self.assertEqual(sizevars.simplify_with_ranges(expr3, {}), expr3)

        # check that it also works with a symbolic modulus
        expr4 = ModularIndexing(i0, 10, i1) + i1 * ModularIndexing(i0, i1 * 10, i2)
        res0 = expr4.subs({i0: 24056, i1: 13, i2: 19})
        simplified = sizevars.simplify_with_ranges(expr4, {})
        res1 = simplified.subs({i0: 24056, i1: 13, i2: 19})
        self.assertEqual(res0, res1)
        self.assertEqual(simplified, ModularIndexing(i0, 10, i1 * i2))

        # and also works with an offset
        self.assertEqual(
            sizevars.simplify_with_ranges(expr4 + 10, {}),
            ModularIndexing(i0, 10, i1 * i2) + 10,
        )

        # works for ModularIndexing + FloorDiv
        expr5 = 197 * FloorDiv(i0, 197) + ModularIndexing(i0, 1, 197)
        simplified = sizevars.simplify_with_ranges(expr5, {})
        self.assertEqual(simplified, i0)
        self.assertEqual(expr5.subs({i0: 39485}), simplified.subs({i0: 39485}))

        # works with a scale
        self.assertEqual(
            sizevars.simplify_with_ranges(2 * expr5, {}),
            2 * i0,
        )

        # divisor != 1
        expr6 = 197 * FloorDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197)
        simplified = sizevars.simplify_with_ranges(expr6, {})
        self.assertEqual(simplified, FloorDiv(i0, 3))
        self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485}))
class CommonTemplate:
# Arithmetic, bitwise, logical and sign ops on bool tensors. The two inputs
# together cover all four (a, b) True/False combinations.
def test_bool(self):
def fn(a, b):
return (
a + b,
a * b,
a & b,
a | b,
a ^ b,
torch.logical_and(a, b),
torch.logical_or(a, b),
torch.logical_not(a),
torch.sign(b),
)
self.common(
fn,
(
torch.tensor([True, False, True, False]),
torch.tensor([False, False, True, True]),
),
)
# Adding an int constant, both via `+` and via torch.add with alpha=2.
def test_add_const_int(self):
def fn(a):
return (a + 1, torch.add(a, 1, alpha=2))
self.common(fn, (torch.randn(32),))
# Adding a float constant to a float tensor.
def test_add_const_float(self):
def fn(a):
return (a + 1.5,)
self.common(fn, (torch.randn(32),))
# In-place add into a non-contiguous (transposed) input, with a right-hand side
# that broadcasts along dim 2.
def test_add_inplace_permuted(self):
def fn(x, y):
return x.add_(y)
x = torch.ones([2, 12, 13, 17]).transpose(1, 2)
y = torch.randn([2, 13, 1, 17])
self.common(fn, (x, y))
# In-place add applied to the result of torch.cat along dim 1.
def test_concat_add_inplace(self):
def fn(x, y, z):
return torch.cat([x, y], dim=1).add_(z)
x = torch.randn([2, 12, 14, 14])
y = torch.randn([2, 12, 14, 14])
z = torch.randn([2, 24, 14, 14])
self.common(fn, (x, y, z))
# torch.abs inside a larger pointwise expression.
def test_abs(self):
def fn(a):
return (a / (torch.abs(a) + 1),)
self.common(fn, (torch.randn(17),))
# torch.sgn on a 0.5-step grid over [-10, 10]. The grid contains 0 and -1, so
# both sgn(a) and sgn(a + 1) see exact zeros.
def test_sgn(self):
def fn(a):
return torch.sgn(a), torch.sgn(a + 1) - 1
self.common(fn, [torch.linspace(-10, 10, 41)])
# torch.randn with generator=None should compile. fn discards the random
# tensor and returns None, so only compilation is exercised, not values.
# Passing an explicit torch.Generator is expected to raise dynamo's
# Unsupported error, with "Generator" in the message.
def test_randn_generator(self):
def fn(a, generator):
torch.randn([20, 20], generator=generator, device=a.device)
self.common(fn, (torch.linspace(-10, 10, 41), None))
# generator not yet supported in dynamo
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "Generator"):
self.common(fn, (torch.linspace(-10, 10, 41), torch.Generator(self.device)))
# torch.sgn on the special values nan, +inf, -inf and 0.
def test_sgn_extremal(self):
def fn(a):
return (torch.sgn(a),)
self.common(fn, [torch.tensor([np.nan, np.inf, -np.inf, 0])])
# torch.maximum / torch.minimum, first on plain random inputs, then with a NaN
# placed at a different index in each operand.
def test_max_min(self):
def fn(a, b):
return (torch.maximum(a, b), torch.minimum(a, b))
self.common(fn, (torch.randn(8), torch.randn(8)))
t1 = torch.randn(8)
t1[0] = float("nan")
t2 = torch.randn(8)
t2[1] = float("nan")
self.common(fn, (t1, t2))
# Regression test for the linked issue: torch.neg on a uint8 tensor, followed
# by torch.maximum against a larger (broadcast) uint8 tensor.
def test_neg_max_uint8(self):
# https://github.com/pytorch/pytorch/issues/93380
def fn(a, b):
c = torch.neg(a)
return torch.maximum(b, c)
a = torch.randint(256, (1,), dtype=torch.uint8)
b = torch.randint(256, (8390,), dtype=torch.uint8)
self.common(fn, (a, b))
# Comparison ops between an integer tensor and float scalars.
def test_compar(self):
def fn(x):
return x.gt(3.5), x.ge(3.5), x.eq(3.5), x.le(2.5), x.lt(3.5), x.ne(3.5)
a = torch.tensor([3])
self.common(fn, (a,))
# Three independent pointwise outputs over shared inputs, one of which
# broadcasts (the test name refers to horizontal fusion).
def test_horizonal_fusion1(self):
def fn(a, b, c):
return (a + b, a - c, b * c)
self.common(
fn, (torch.randn(8, 16, 16), torch.randn(8, 16, 16), torch.randn(1, 16, 1))
)
# Three unrelated pointwise ops on inputs of different shapes.
def test_horizonal_fusion2(self):
def fn(a, b, c):
return a + 1, b + 2, c + 3
self.common(fn, (torch.randn(8, 16, 8), torch.randn(8, 16), torch.randn(16, 8)))
# A chain of dependent pointwise ops (taken from torchbench
# pyhpc_equation_of_state). The assertion below expects the whole chain to be
# emitted as exactly one generated kernel.
def test_vertical_fusion1(self):
def fn(sa, ct, p):
# From torchbench.pyhpc_equation_of_state
v17 = -3.087032500374211e-7
v18 = -1.988366587925593e-8
v19 = -1.061519070296458e-11
v20 = 1.550932729220080e-10
t15 = v19 * ct
t19 = v17 + ct * (v18 + t15) + v20 * sa
t20 = 1.0 / t19
t128 = t19 * p
return t20 + t128
self.common(
fn,
(
torch.randn(204, 204, 26),
torch.randn(204, 204, 26),
torch.randn(26),
),
)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
# test_operators.realize forces an intermediate buffer. The assertion expects
# 2 IR nodes before fusion: the realized `a * 2` and the final `b * 2`.
def test_forced_buffer_realize(self):
# Test torch._test_inductor_realize forces a buffer to be realized
def fn(a):
b = test_operators.realize(a * 2)
return (b * 2,)
self.common(fn, (torch.randn(10),))
self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 2)
# Same computation as test_vertical_fusion1, but every intermediate is forced
# into its own buffer with realize. The assertions expect 5 IR nodes before
# fusion, and 1 generated kernel on CUDA vs 3 on CPU.
def test_scheduler_vertical_fusion1(self):
realize = test_operators.realize
def fn(sa, ct, p):
# From torchbench.pyhpc_equation_of_state
v17 = -3.087032500374211e-7
v18 = -1.988366587925593e-8
v19 = -1.061519070296458e-11
v20 = 1.550932729220080e-10
t15 = realize(v19 * ct)
t19 = realize(v17 + ct * (v18 + t15) + v20 * sa)
t20 = realize(1.0 / t19)
t128 = realize(t19 * p)
return t20 + t128
self.common(
fn,
(
torch.randn(204, 204, 26),
torch.randn(204, 204, 26),
torch.randn(26),
),
)
self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 5)
self.assertEqual(
torch._inductor.metrics.generated_kernel_count,
1 if self.device == "cuda" else 3,
)
# Sum over the last dim of a pointwise result.
def test_sum1(self):
def fn(a, b):
return ((a + b).sum(-1),)
self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
# Two reductions of the same expression: over dims [1, 2] and over the last dim.
def test_sum2(self):
def fn(a, b):
return ((a + b).sum([1, 2]), (a + b).sum(-1))
self.common(fn, (torch.randn(8, 9, 3, 21), torch.randn(8, 9, 3, 21)))
# Returns a pointwise result, its reduction, and an unrelated squeeze + add.
# The tolerances are loosened to the values recorded in the mismatch report
# comment below.
def test_sum3(self):
def fn(a, b):
r1 = a + b
r2 = r1.sum(-1)
r3 = torch.squeeze(b) + 10
return (r1, r2, r3)
# Mismatched elements: 2 / 10 (20.0%)
# Greatest absolute difference: 0.0029296875 at index (8,) (up to 1e-05 allowed)
# Greatest relative difference: 0.0017482517482517483 at index (6,) (up to 0.001 allowed)
self.common(fn, (torch.randn(10, 10), torch.randn(1, 10)), atol=1e-5, rtol=2e-3)
# Two chained reductions with pointwise ops in between; every intermediate is
# also returned.
def test_sum4(self):
def fn(a):
b = a + 1
c = b.sum(-1)
d = c + 3
e = d.sum(-1)
f = e + 5
return (f, e, d, c, b)
self.common(fn, (torch.randn(1, 16, 8, 8),))
# Same chain as test_sum4, but only the final value is returned, and the input
# uses odd sizes (17, 9).
def test_sum5(self):
def fn(a):
b = a + 1
c = b.sum(-1)
d = c + 3
e = d.sum(-1)
f = e + 5
return (f,)
self.common(fn, (torch.randn(1, 17, 8, 9),))
# Full reductions (sum/max/min/argmax/argmin) over an input containing -inf,
# 0 and inf.
def test_reduction1(self):
def fn(a):
return (a.sum(), a.max(), a.min(), a.argmax(), a.argmin())
self.common(fn, (torch.tensor([float("-inf"), 0.0, float("inf")]),))
# Reductions over an all-+inf input (argmax omitted, see FIXME). Skipped on
# x86 macOS.
@skip_if_x86_mac()
def test_reduction2(self):
def fn(a):
# FIXME: a.argmax
return (a.sum(), a.max(), a.min(), a.argmin())
self.common(fn, (torch.full((4,), float("inf")),))
# Reductions over an all--inf input (argmin omitted, see FIXME). Skipped on
# x86 macOS.
@skip_if_x86_mac()
def test_reduction3(self):
def fn(a):
# FIXME: a.argmin
return (a.sum(), a.max(), a.min(), a.argmax())
self.common(fn, (torch.full((4,), float("-inf")),))
# argmax/argmin on all-ones inputs, where every index ties. Skipped on CPU
# because of non-deterministic results there.
def test_reduction4(self):
if self.device == "cpu":
raise unittest.SkipTest("Non-deterministic CPU results")
def fn(a):
return (a.argmax(-1), a.argmin(-1))
inputs = (torch.ones(128), torch.ones(4, 4, 1))
for i in inputs:
self.common(fn, (i,))
# Reductions over an all--inf input with unroll_reductions_threshold=1.
# Skipped on CPU because of non-deterministic results there.
@config.patch(unroll_reductions_threshold=1)
def test_reduction5(self):
if self.device == "cpu":
raise unittest.SkipTest("Non-deterministic CPU results")
def fn(a):
return (a.sum(), a.max(), a.min(), a.argmax())
self.common(fn, (torch.full((4,), float("-inf")),))
# A range of reductions over a size-3 last dim, checked twice:
# - with unroll_reductions_threshold=8, so the small reductions are unrolled;
# - with threshold=1, so they are not.
# Dynamo is reset in between so the second run recompiles.
def test_unroll_small_reduction(self):
def fn(x):
val1, index1 = x.min(-1)
val2, index2 = x.max(-1)
return (
val1,
index1,
val2,
index2,
x.sum(-1),
(x > 1).any(-1),
(x > 0).all(-1),
x.argmin(-1),
x.argmax(-1),
x.amin(-1),
x.amax(-1),
x.aminmax(),
)
with config.patch(unroll_reductions_threshold=8):
# small sized reductions will get unrolled
self.common(fn, (torch.randn(8, 3),))
torch._dynamo.reset()
with config.patch(unroll_reductions_threshold=1):
# make sure things also work if they aren't unrolled
self.common(fn, (torch.randn(8, 3),))
# Mean over a large float16 input. Skipped on CPU, where fp16 is not yet
# implemented according to the comment below.
def test_multilayer_low_prec(self):
# fp16 nyi for cpu
if self.device == "cpu":
raise unittest.SkipTest("requires CUDA")
def fn(a):
return torch.mean(a)
self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))
def test_expanded_reduction(self):
if self.device == "cpu":
raise unittest.SkipTest(
"https://github.com/pytorch/torchdynamo/issues/1697"