forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_skip_non_tensor.py
113 lines (81 loc) · 2.57 KB
/
test_skip_non_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.testing import CompileCounter
class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
def test_add_tensor1(self):
def fn(a, b):
return a + b
counter = CompileCounter()
x = torch.randn(4)
y = 5
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
opt_fn(x, y)
assert counter.op_count == 1
def test_add_tensor2(self):
def fn(a, b):
return torch.add(a, b)
counter = CompileCounter()
x = torch.randn(4)
y = 5
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
opt_fn(x, y)
assert counter.op_count == 1
def test_add_tensor_list(self):
def fn(lst):
return lst[0] + lst[1]
counter = CompileCounter()
x = torch.randn(4)
y = 5
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
opt_fn([x, y])
assert counter.op_count == 1
def test_add_tensor_dict(self):
def fn(dt):
return dt["a"] + dt["b"]
counter = CompileCounter()
x = torch.randn(4)
y = 5
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
opt_fn({"a": x, "b": y})
assert counter.op_count == 1
def test_add_skip(self):
def fn(a, b):
return a + b
counter = CompileCounter()
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
x = 4
y = 5
opt_fn(x, y)
assert counter.op_count == 0
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
def test_recursive_list(self):
def fn(x):
return x
counter = CompileCounter()
x = []
x.append(x)
with torch._dynamo.optimize_assert(counter):
fn(x)
assert counter.op_count == 0
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
def test_custom_list(self):
def fn(x):
return x[0] + x[1]
counter = CompileCounter()
class Foo(list):
def __iter__(self):
raise Exception()
def __len__(self):
raise Exception()
x = Foo()
x.append(torch.randn(4))
x.append(torch.randn(4))
with torch._dynamo.optimize_assert(counter):
fn(x)
assert counter.op_count == 0
if __name__ == "__main__":
    # When executed as a script, hand off to dynamo's ``run_tests`` entry
    # point via the ``torch._dynamo.test_case`` module imported at the top.
    torch._dynamo.test_case.run_tests()