# test_export_mutations.py -- copied from a fork of pytorch/pytorch (131 lines, 101 loc, 4.4 KB).
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
class MutationExportTests(torch._dynamo.test_case.TestCase):
def check_failure_on_export(self, mod, *args):
with self.assertRaises(AssertionError):
torch._dynamo.export(mod, *args)
def check_same_with_export(self, mod, arg):
real_result = mod(arg)
graph, _ = torch._dynamo.export(mod, arg)
result = graph(arg)
self.assertTrue(torch._dynamo.utils.same(result, real_result))
def test_module_attribute_mutation_violation_positive_1(self):
# Mutating attribute with a Tensor type
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.Tensor(3, 2)
def forward(self, x):
self.a = self.a.to(torch.float64)
return x.sum() + self.a.sum()
self.check_failure_on_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_positive_2(self):
# Mutating attribute with a scalar type
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = 2
def forward(self, x):
self.a = self.a * 3
return x.sum() + self.a
self.check_failure_on_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_positive_3(self):
# Setting a new attribute inside forward()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.Tensor(3, 2)
def forward(self, x):
self.b = 2
return x.sum() + self.a.sum() + self.b
self.check_failure_on_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_positive_4(self):
# Mutating attribute with an inline function
class Foo(torch.nn.Module):
def add(self, a, b):
return a + b
def forward(self, x):
self.a = self.add(1, 2) * self.add(3, 4)
return x.sum() + self.a
self.check_failure_on_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_negative_1(self):
# Mutating attribute with a Tensor type inside __init__ but
# not in forward()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.Tensor(3, 2)
def forward(self, x):
return x.sum() + self.a.to(torch.float64).sum()
self.check_same_with_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_negative_2(self):
# Mutating attribute with a Tensor type inside __init__ twice
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.Tensor(3, 2)
self.a = self.a.to(torch.float64)
def forward(self, x):
return x.sum() + self.a.sum()
self.check_same_with_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_negative_3(self):
# Mutating local variable inside forward()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.Tensor(3, 2)
def forward(self, x):
b = 1
b = b * 5
return x.sum() + self.a.sum() + b
self.check_same_with_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_negative_4(self):
# Mutating attribute with a Tensor type
# But not exporting but using eager mode as well as dynamo optimize mode
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.Tensor(3, 2)
def forward(self, x):
self.a = self.a.to(torch.float64)
return x.sum() + self.a.sum()
mod = Foo()
arg = torch.Tensor(3, 2)
real_result = mod(arg)
opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
self.assertTrue(torch._dynamo.utils.same(opt_mod(arg), real_result))
if __name__ == "__main__":
    # Script entry point: hand control to the dynamo test runner.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()