forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_python_autograd.py
287 lines (219 loc) · 8.62 KB
/
test_python_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# Owner(s): ["module: dynamo"]
from typing import Callable, Dict, List, NamedTuple, Optional
import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo
to push the limits of what it can do.
"""
# NOTE: this string sits after the imports, so it is not the module __doc__.

# Counter behind fresh_name(); reset_tape() sets it back to 0.
_name: int = 0


def fresh_name() -> str:
    """Return the next auto-generated variable name ("v0", "v1", "v2", ...)."""
    global _name
    current = _name
    _name = current + 1
    return f"v{current}"
class Variable:
    """A tensor value plus a unique name that identifies it on the gradient tape.

    Arithmetic on Variables is routed through the module-level ``operator_*``
    functions, which compute the forward value and append a ``TapeEntry``
    describing how to propagate gradients back through the operation.
    """

    def __init__(self, value: torch.Tensor, name: Optional[str] = None):
        self.value = value
        # Any falsy name (None or "") falls back to an auto-generated v0, v1, ...
        self.name = name or fresh_name()

    # We need to start with some tensors whose values were not computed
    # inside the autograd. This function constructs leaf nodes.
    @staticmethod
    def constant(value: torch.Tensor, name: Optional[str] = None):
        """Create a leaf Variable; no tape entry is recorded for it."""
        return Variable(value, name)

    def __repr__(self):
        return repr(self.value)

    def __mul__(self, rhs: "Variable") -> "Variable":
        """Pointwise ``self * rhs`` with gradient tracking (see ``operator_mul``).

        ``rhs`` may also be the float ``1.0``, which ``operator_mul`` treats as
        the identity and returns ``self`` without recording anything.
        """
        # operator_mul is defined later in this file
        return operator_mul(self, rhs)

    def __add__(self, rhs: "Variable") -> "Variable":
        """Pointwise ``self + rhs`` with gradient tracking (see ``operator_add``)."""
        return operator_add(self, rhs)

    def sum(self, name: Optional[str] = None) -> "Variable":
        """Sum all elements into a 0-d Variable, optionally with an explicit name."""
        return operator_sum(self, name)

    def expand(self, sizes: List[int]) -> "Variable":
        """Broadcast a 0-d Variable to ``sizes`` with gradient tracking.

        ``sizes`` is forwarded to ``torch.Tensor.expand``, so a ``torch.Size``
        is accepted as well as a list of ints.
        """
        return operator_expand(self, sizes)
class TapeEntry(NamedTuple):
    """One recorded operation on the gradient tape.

    ``grad`` walks the tape in reverse and calls ``propagate`` with the
    gradients of ``outputs`` to obtain the gradients of ``inputs``.
    """

    # names of the inputs to the original computation
    inputs: List[str]
    # names of the outputs of the original computation
    outputs: List[str]
    # apply chain rule: maps dL/d(outputs) to dL/d(inputs), one per input name
    propagate: "Callable[[List[Variable]], List[Variable]]"


# Global tape that every operator_* function appends to; cleared by reset_tape().
gradient_tape: List[TapeEntry] = []
def reset_tape():
    """Start a fresh recording: empty the global tape and restart naming at v0."""
    global _name
    _name = 0  # fresh_name() will hand out "v0" next
    # Cleared in place, so no `global gradient_tape` declaration is needed.
    gradient_tape.clear()
def grad(L, desired_results: List[Variable]) -> List[Variable]:
    """Reverse-mode differentiate ``L`` with respect to each of ``desired_results``.

    Walks ``gradient_tape`` backwards, calling each entry's ``propagate`` to turn
    dL/d(outputs) into dL/d(inputs), and sums the contributions for names that
    are used more than once.

    Args:
        L: the Variable being differentiated. It is seeded with
            dL/dL = ``torch.ones(())`` (a 0-d tensor), so it is presumably a
            scalar loss.
        desired_results: Variables whose gradients are requested.

    Returns:
        One element per entry of ``desired_results``, in order: a Variable
        holding dL/dX, or ``None`` when X never contributed to ``L``.
    """
    # this map holds dL/dX for all values X, keyed by Variable name
    dL_d: Dict[str, Variable] = {}
    # It starts by initializing the 'seed' dL/dL, which is 1
    dL_d[L.name] = Variable(torch.ones(()))

    # Look up dL_d entries. A variable that never contributed to the loss has
    # no entry, and its gradient is reported as None rather than as zeros.
    def gather_grad(entries: List[str]):
        return [dL_d[entry] if entry in dL_d else None for entry in entries]

    # propagate the gradient information backward
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if all(dL_doutput is None for dL_doutput in dL_doutputs):
            # None of this op's outputs reached the loss, so it contributes no
            # gradient; skip it instead of propagating zeros.
            continue

        # perform chain rule propagation specific to each compute
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input.
        # Each use of a variable produces some gradient dL_dinput for that
        # use. The multivariate chain rule tells us it is safe to sum
        # all the contributions together.
        for input_name, dL_dinput in zip(entry.inputs, dL_dinputs):
            if input_name not in dL_d:
                dL_d[input_name] = dL_dinput
            else:
                # Accumulate out of place. propagate may return the *same*
                # Variable for several inputs (operator_mul's 1.0 peephole hands
                # back its argument unchanged), and that value may be a stride-0
                # view produced by expand(). An in-place `+=` would either
                # silently change the gradient stored under other names or raise
                # on the expanded view (e.g. for `(a + a).sum()`).
                dL_d[input_name] = Variable(dL_d[input_name].value + dL_dinput.value)

    return gather_grad(desired.name for desired in desired_results)
def operator_mul(self: Variable, rhs: Variable) -> Variable:
    """Pointwise product ``self * rhs``, recorded on the gradient tape.

    Multiplying by the float literal ``1.0`` (as ``operator_add``'s backward
    does) is treated as the identity: ``self`` is returned and nothing is
    recorded.
    """
    if isinstance(rhs, float) and rhs == 1.0:
        # peephole optimization
        return self

    lhs = self
    out = Variable(lhs.value * rhs.value)
    # print(f'{out.name} = {lhs.name} * {rhs.name}')

    def propagate(dL_doutputs: List[Variable]):
        # out = lhs * rhs, so d(out)/d(lhs) = rhs and d(out)/d(rhs) = lhs;
        # the chain rule scales the incoming gradient by each partial.
        (dL_dout,) = dL_doutputs
        return [dL_dout * rhs, dL_dout * lhs]

    # record the inputs, outputs and backward rule of this op on the tape
    gradient_tape.append(
        TapeEntry(inputs=[lhs.name, rhs.name], outputs=[out.name], propagate=propagate)
    )
    return out
def operator_add(self: Variable, rhs: Variable) -> Variable:
    """Pointwise sum ``self + rhs``, recorded on the gradient tape.

    Both partial derivatives of ``self + rhs`` are 1, so unlike multiplication
    the backward closure does not capture either operand.
    """
    out = Variable(self.value + rhs.value)
    # print(f'{out.name} = {self.name} + {rhs.name}')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dout,) = dL_doutputs
        # Scale by the unit partials d(out)/d(self) == d(out)/d(rhs) == 1.0.
        return [dL_dout * 1.0, dL_dout * 1.0]

    entry = TapeEntry(inputs=[self.name, rhs.name], outputs=[out.name], propagate=propagate)
    gradient_tape.append(entry)
    return out
def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
    """Sum every element of ``self`` into a 0-d Variable, recorded on the tape.

    The backward pass broadcasts the incoming 0-d gradient back to ``self``'s
    shape via ``Variable.expand``, since every element contributes equally.

    Args:
        self: the Variable to reduce.
        name: optional explicit name for the result (otherwise auto-generated).
    """
    r = Variable(torch.sum(self.value), name=name)
    # print(f'{r.name} = {self.name}.sum()')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dr,) = dL_doutputs
        size = self.value.size()
        # Pass the torch.Size as ONE argument: Variable.expand takes a single
        # `sizes` parameter. Unpacking it (expand(*size)) only worked for 1-D
        # inputs and raised TypeError for 0-d or multi-dimensional ones.
        return [dL_dr.expand(size)]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=propagate)
    )
    return r
def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
    """Broadcast a 0-d Variable to ``sizes``, recorded on the gradient tape.

    Every output element is a copy of the same scalar, so the backward pass
    reduces the incoming gradient back to a scalar with ``sum``.
    """
    assert self.value.dim() == 0  # only works for scalars
    out = Variable(self.value.expand(sizes))
    # print(f'{out.name} = {self.name}.expand({sizes})')

    def propagate(dL_doutputs: List[Variable]):
        (dL_dout,) = dL_doutputs
        total = dL_dout.sum()
        return [total]

    entry = TapeEntry(inputs=[self.name], outputs=[out.name], propagate=propagate)
    gradient_tape.append(entry)
    return out
def simple(a, b):
    """Return ``(a + b) * b``; the small function differentiated by the tests."""
    return (a + b) * b
class TestPythonAutograd(TestCase):
    """Checks that Dynamo captures the pure-Python autograd above without graph breaks."""

    def _common(self, fn, expected_ops):
        """Compile ``fn`` with ``optimize_assert`` and compare it against eager.

        Runs the compiled ``fn`` on two pairs of random 10-element tensors,
        checks each result matches eager execution via ``same``, and asserts
        that exactly one frame was compiled containing ``expected_ops`` ops.
        The tape is reset around every call so no state leaks between runs.
        """
        args1 = [torch.randn(10), torch.randn(10)]
        args2 = [torch.randn(10), torch.randn(10)]
        cnt = CompileCounter()
        # optimize_assert errors on a graph break instead of falling back.
        fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
        reset_tape()
        res1 = fn_dynamo(*args1)
        reset_tape()
        res2 = fn_dynamo(*args2)
        reset_tape()
        self.assertTrue(same(res1, fn(*args1)))
        reset_tape()
        self.assertTrue(same(res2, fn(*args2)))
        reset_tape()
        # Both calls are expected to share a single compiled frame.
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, expected_ops)

    def test_forwards1(self):
        """Forward pass only; the returned loss is a Variable."""

        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return loss

        self._common(fn, 3)

    def test_forwards2(self):
        """Like test_forwards1, but the tape is reset inside the compiled function."""

        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            reset_tape()
            return loss

        self._common(fn, 3)

    def test_backwards1(self):
        """Forward pass plus grad() traced together in one frame."""

        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return grad(loss, [a, b])

        self._common(fn, 8)

    def test_backwards2(self):
        """Like test_backwards1, but the tape is reset inside the compiled function."""

        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            res = grad(loss, [a, b])
            reset_tape()
            return res

        self._common(fn, 8)

    def test_split(self):
        """Compile forward and grad separately; expect two frames, 8 ops total."""
        v1 = Variable.constant(torch.randn(10), name="a")
        v2 = Variable.constant(torch.randn(10), name="b")
        cnt = CompileCounter()

        def forward(a, b):
            return simple(a, b).sum()

        # eager reference results
        reset_tape()
        loss1 = forward(v1, v2)
        grad1 = grad(loss1, [v1, v2])

        # compiled run; the tape recorded by opt_forward is consumed by opt_grad
        reset_tape()
        opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
        opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
        loss2 = opt_forward(v1, v2)
        # force two frames
        grad2 = opt_grad(loss2, [v1, v2])

        self.assertTrue(same(loss1, loss2))
        self.assertTrue(same(grad1, grad2))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 8)
if __name__ == "__main__":
    # Run this file's tests via torch._dynamo.test_case.run_tests when executed directly.
    run_tests()