forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfor_loop_test.py
409 lines (348 loc) · 13.8 KB
/
for_loop_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax import random
from jax._src import test_util as jtu
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
# Have JAX's config parse its flags through absl before any test runs.
jax.config.parse_flags_with_absl()
def remat_of_for_loop(nsteps, body, state, **kwargs):
  """Runs `for_loop.for_loop` wrapped in `jax.remat`.

  Args:
    nsteps: forwarded unchanged as the first argument of `for_loop.for_loop`.
    body: loop body, forwarded unchanged.
    state: initial loop state; this is the argument of the remat-wrapped
      function.
    **kwargs: extra keyword arguments forwarded to `for_loop.for_loop`.

  Returns:
    Whatever the wrapped `for_loop.for_loop` call returns.
  """
  def run_loop(loop_state):
    return for_loop.for_loop(nsteps, body, loop_state, **kwargs)

  rematerialized = jax.remat(run_loop)
  return rematerialized(state)
def nested_for_loop(nsteps, body, state, **kwargs):
  """Runs `body` in an inner for_loop nested inside a one-step outer for_loop.

  The outer loop owns the refs created from `state`; the inner loop carries no
  state of its own (its state is `()`) and calls `body` on the outer refs.

  Args:
    nsteps: number of steps of the inner loop.
    body: loop body, called as `body(i, refs)` with the outer loop's refs.
    state: initial state of the outer loop.
    **kwargs: accepted but not forwarded to either loop.

  Returns:
    The result of the outer `for_loop.for_loop` call.
  """
  del kwargs  # Accepted but unused.

  def run_inner_loop(_, outer_refs):
    def apply_body(step, _):
      # Any value `body` returns is dropped.
      body(step, outer_refs)

    for_loop.for_loop(nsteps, apply_body, ())

  return for_loop.for_loop(1, run_inner_loop, state)
# Pairs of (for_loop variant, name). The name becomes the test-case name in
# `_for_loop_impls`; every variant is called as `impl(nsteps, body, state)`.
FOR_LOOP_IMPLS = [
    (for_loop.for_loop, 'for_loop'),
    # jit-wrapped, with arguments 0 and 1 (nsteps and body) marked static.
    (jax.jit(for_loop.for_loop, static_argnums=(0, 1)), 'jit_for_loop'),
    (remat_of_for_loop, 'remat_for_loop'),
    (nested_for_loop, 'nested_for_loop'),
    # Plain for_loop with `unroll=3` bound in.
    (partial(for_loop.for_loop, unroll=3), 'unrolled_for_loop'),
]
def _for_loop_impls(f):
  """Decorator that runs test `f` once per entry of `FOR_LOOP_IMPLS`.

  Each generated case is named after the entry's name and receives the entry's
  callable as the `for_impl` keyword argument.
  """
  named_cases = (
      {"testcase_name": name, "for_impl": impl}
      for impl, name in FOR_LOOP_IMPLS
  )
  decorate = parameterized.named_parameters(named_cases)
  return decorate(f)
class ForLoopTest(jtu.JaxTestCase):
  """Reads and writes through refs, run against every entry of FOR_LOOP_IMPLS."""

  @_for_loop_impls
  def test_for_loop_impl_trivial(self, for_impl):
    # No state and a body that does nothing: the loop yields `None`.
    def do_nothing(i, _):
      return None

    result = for_impl(5, do_nothing, None)
    self.assertIsNone(result)

  @_for_loop_impls
  def test_for_loop_can_write_to_ref(self, for_impl):
    # One step that stores a constant.
    def store_one(_, x_ref):
      x_ref[()] = jnp.float32(1.)

    result = for_impl(1, store_one, jnp.float32(0.))
    self.assertEqual(result, 1.)

    # Two steps storing the step index: the last write (i == 1) wins.
    def store_index(i, x_ref):
      x_ref[()] = jnp.float32(i)

    result = for_impl(2, store_index, jnp.float32(0.))
    self.assertEqual(result, 1.)

    # Two steps storing twice the step index: the last write is 1 * 2.
    def store_doubled_index(i, x_ref):
      x_ref[()] = jnp.float32(i) * 2.

    result = for_impl(2, store_doubled_index, jnp.float32(0.))
    self.assertEqual(result, 2.)

  @_for_loop_impls
  def test_for_loop_can_write_to_multiple_refs(self, for_impl):
    def store_one_and_two(_, refs):
      first_ref, second_ref = refs
      first_ref[()] = jnp.float32(1.)
      second_ref[()] = jnp.float32(2.)

    initial_state = (jnp.float32(0.), jnp.float32(0.))
    first, second = for_impl(1, store_one_and_two, initial_state)
    self.assertEqual(first, 1.)
    self.assertEqual(second, 2.)

  @_for_loop_impls
  def test_for_loop_can_read_from_ref(self, for_impl):
    # The body only reads, so the state must come back unchanged.
    def read_only(_, x_ref):
      x_ref[()]  # pylint: disable=pointless-statement

    result = for_impl(1, read_only, jnp.float32(0.))
    self.assertEqual(result, 0.)

  @_for_loop_impls
  def test_for_loop_can_read_from_and_write_to_ref(self, for_impl):
    # Five increments of a scalar that starts at zero.
    def increment(_, x_ref):
      current = x_ref[()]
      x_ref[()] = current + jnp.float32(1.)

    result = for_impl(5, increment, jnp.float32(0.))
    self.assertEqual(result, 5.)

  @_for_loop_impls
  def test_for_loop_can_read_from_and_write_to_refs(self, for_impl):
    # Both refs receive the first ref's old value plus one on every step.
    def increment_both(_, refs):
      first_ref, second_ref = refs
      current = first_ref[()]
      second_ref[()] = current + 1.
      first_ref[()] = current + 1.

    first, second = for_impl(5, increment_both, (0., 0.))
    self.assertEqual(first, 5.)
    self.assertEqual(second, 5.)

  @_for_loop_impls
  def test_for_loop_can_read_from_and_write_to_ref_slice(self, for_impl):
    # Step i touches only element i of a vector.
    def increment_element(i, x_ref):
      element = x_ref[i]
      x_ref[i] = element + jnp.float32(1.)

    result = for_impl(4, increment_element, jnp.ones(4, jnp.float32))
    np.testing.assert_allclose(result, 2 * jnp.ones(4, jnp.float32))

    # Step i adds column 0 of row i into column 1 of the same row.
    def add_first_column_into_second(i, x_ref):
      first = x_ref[i, 0]
      x_ref[i, 1] = first + x_ref[i, 1]

    table = jnp.arange(8.).reshape((4, 2))
    result = for_impl(4, add_first_column_into_second, table)
    expected = jnp.array([[0., 1.], [2., 5.], [4., 9.], [6., 13.]])
    np.testing.assert_allclose(result, expected)

  @_for_loop_impls
  @jax.legacy_prng_key('allow')
  def test_for_loop_can_implement_cumsum(self, for_impl):
    def cumsum(values):
      def add_next(i, refs):
        values_ref, running_ref = refs
        running_ref[i + 1] = running_ref[i] + values_ref[i]

      length = values.shape[0]
      # One extra leading slot holds the initial zero; it is dropped below.
      running = jnp.zeros(length + 1, values.dtype)
      _, running_out = for_impl(length, add_next, (values, running))
      return running_out[1:]

    key = jax.random.PRNGKey(0)
    samples = jax.random.normal(key, (8,))
    np.testing.assert_allclose(cumsum(samples), jnp.cumsum(samples), rtol=1e-6)
def for_body_swap(i, refs):
  """Loop body that exchanges element `i` between the two refs in `refs`."""
  first_ref, second_ref = refs
  first_val = first_ref[i]
  second_val = second_ref[i]
  # Both reads happen before either write, so neither value is clobbered.
  second_ref[i] = first_val
  first_ref[i] = second_val
def swap_ref(a, b):
  """Reference for `for_body_swap`: returns the arguments in swapped order."""
  swapped = (b, a)
  return swapped
def for_body_swap_swap(i, refs):
  """Loop body that applies `for_body_swap` to element `i` twice in a row."""
  for _ in range(2):
    for_body_swap(i, refs)
def swap_swap_ref(a, b):
  """Reference for `for_body_swap_swap`: both arguments come back unchanged."""
  return a, b
def for_body_sincos(i, refs):
  """Loop body storing sin(cos(.)) of the first ref's element `i` into the second."""
  src_ref, dst_ref = refs
  cos_val = jnp.cos(src_ref[i])
  dst_ref[i] = jnp.sin(cos_val)
def sincos_ref(x, y):
  """Reference for `for_body_sincos`: returns `(x, sin(cos(x)))`; `y` is unused."""
  return x, jnp.sin(jnp.cos(x))
def for_body_sincostan(i, refs):
  """Loop body storing tan(sin(cos(.))) of the first ref's element `i` into the second."""
  src_ref, dst_ref = refs
  cos_val = jnp.cos(src_ref[i])
  sin_val = jnp.sin(cos_val)
  dst_ref[i] = jnp.tan(sin_val)
def sincostan_ref(x, y):
  """Reference for `for_body_sincostan`: returns `(x, tan(sin(cos(x))))`; `y` is unused."""
  return x, jnp.tan(jnp.sin(jnp.cos(x)))
def for_body_accum(i, refs):
  """Loop body for a running sum.

  Slot `i + 1` of the second ref becomes slot `i` of the second ref plus
  element `i` of the first ref.
  """
  x_ref, total_ref = refs
  previous_total = total_ref[i]
  total_ref[i + 1] = previous_total + x_ref[i]
def accum_ref(x, accum):
  """Reference for `for_body_accum`, written as a plain Python loop.

  For every slot from 1 up to (but excluding) `x.shape[0]`, the slot of `accum`
  is set to the previous slot of `accum` plus the previous element of `x`.

  Returns:
    `(x, accum)` with `accum` being the updated array.
  """
  for dst in range(1, x.shape[0]):
    src = dst - 1
    accum = accum.at[dst].set(accum[src] + x[src])
  return x, accum
def for_body_sin_sq(i, refs):
  """Loop body storing sin(x * x) for element `i`, via a write and read-back.

  The input element is first copied into the output ref and then read back
  from it, so the body performs a write followed by a read of the same slot
  before the final write.
  """
  src_ref, dst_ref = refs
  dst_ref[i] = src_ref[i]
  staged = dst_ref[i]
  dst_ref[i] = jnp.sin(staged * staged)
def sin_sq_ref(x, y):
  """Reference for `for_body_sin_sq`: returns `(x, sin(x * x))`; `y` is unused."""
  return x, jnp.sin(x * x)
def for_body_reverse(i, refs):
  """Loop body copying the mirrored element of the first ref into slot `i` of the second.

  The mirror index is measured against the length of the second ref.
  """
  x_ref, y_ref = refs
  last = y_ref.shape[0] - 1
  y_ref[i] = x_ref[last - i]
def reverse_ref(x, y):
  """Reference for `for_body_reverse`: returns `(x, x reversed)`; `y` is unused."""
  return x, x[::-1]
def for_body_noop(i, refs):
  """Loop body that deliberately reads and writes nothing."""
  return None
def noop_ref(x, y):
  """Reference for `for_body_noop`: both arguments come back unchanged."""
  return x, y
# Second implementation that the transformation tests compare every for_loop
# variant against; it is called with the same `(nsteps, body, state)` arguments.
for_reference = for_loop.discharged_for_loop
class ForLoopTransformationTest(jtu.JaxTestCase):
  """Checks jvp, linearize and grad of for_loop against reference results.

  The parameterized tests compare each for_loop variant in `FOR_LOOP_IMPLS`
  with both `for_reference` and a pure reference function for the same body.
  """

  # (name, loop body, pure reference, argument shapes, number of steps) for
  # the loop bodies shared by the jvp, linearize and grad tests.
  _BODY_CASES = [
      ("swap", for_body_swap, swap_ref, [(4,), (4,)], 4),
      ("swap_swap", for_body_swap_swap, swap_swap_ref, [(4,), (4,)], 4),
      ("sincos", for_body_sincos, sincos_ref, [(4,), (4,)], 4),
      ("sincostan", for_body_sincostan, sincostan_ref, [(4,), (4,)], 4),
      # `accum` writes slot i + 1, so it runs one step fewer than the length.
      ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
      ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
      ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
  ]
  # The grad test additionally covers a body that touches neither ref.
  _GRAD_BODY_CASES = [
      ("noop", for_body_noop, noop_ref, [(4,), (4,)], 4),
      *_BODY_CASES,
  ]

  @jtu.sample_product(
      [dict(for_body_name=for_body_name, f=for_body, ref=ref,
            body_shapes=body_shapes, n=nsteps)
       for for_body_name, for_body, ref, body_shapes, nsteps in _BODY_CASES],
      [dict(for_impl=for_impl, impl_name=impl_name)
       for for_impl, impl_name in FOR_LOOP_IMPLS],
  )
  @jtu.skip_on_devices("gpu", "cpu")  # TODO(mattjj,sharadmv, dougalm): timeouts?
  def test_for_jvp(self, f, ref, body_shapes, n, for_impl, for_body_name,
                   impl_name):
    for_ = for_impl
    rng = self.rng()
    args = [rng.randn(*s) for s in body_shapes]
    tol = {np.float64: 1e-12, np.float32: 1e-4}
    # The primal values are reused as the tangents.
    ans = jax.jvp(lambda *args: for_(n, f, args), args, args)
    ans_discharged = jax.jvp(lambda *args: for_reference(n, f, args), args,
                             args)
    expected = jax.jvp(ref, args, args)
    self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol,
                        atol=tol)
    self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
    jtu.check_grads(partial(for_, n, f), (args,), order=2, modes=["fwd"])

  @jtu.sample_product(
      [dict(for_body_name=for_body_name, f=for_body, ref=ref,
            body_shapes=body_shapes, n=nsteps)
       for for_body_name, for_body, ref, body_shapes, nsteps in _BODY_CASES],
      [dict(for_impl=for_impl, impl_name=impl_name)
       for for_impl, impl_name in FOR_LOOP_IMPLS],
  )
  @jtu.skip_on_devices("gpu", "cpu")  # TODO(mattjj,sharadmv, dougalm): timeouts?
  def test_for_linearize(self, f, ref, body_shapes, n, for_impl, for_body_name,
                         impl_name):
    for_ = for_impl
    rng = self.rng()
    args = [rng.randn(*s) for s in body_shapes]
    tol = {np.float64: 1e-12, np.float32: 1e-4}
    # Element [1] of the linearize result is the linear function; it is
    # applied to the primal values themselves.
    ans = jax.linearize(lambda *args: for_(n, f, args), *args)[1](*args)
    ans_discharged = jax.linearize(lambda *args: for_reference(n, f, args),
                                   *args)[1](*args)
    expected = jax.linearize(ref, *args)[1](*args)
    self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol,
                        atol=tol)
    self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)

  def test_for_loop_invar(self):
    def f(x):
      s = jnp.ones((2, 32), x.dtype)
      def body(i, refs):
        x_ref, y_ref = refs
        y_ref[i] = s * x_ref[i] * jnp.cos(s)
      # We should save `s` and `jnp.cos(s)` as residuals and not broadcast
      # them.
      return for_loop.for_loop(x.shape[0], body, (x, jnp.zeros_like(x)))
    _, f_lin = jax.linearize(f, jnp.ones((5, 2, 32)))
    jaxpr = jax.make_jaxpr(f_lin)(jnp.ones((5, 2, 32)))
    # Exactly two constants should keep the un-broadcast (2, 32) shape.
    consts = [v.aval for v in jaxpr.jaxpr.constvars
              if v.aval.shape == (2, 32)]
    self.assertLen(consts, 2)

    def loss(A):
      def step(x, _):
        return jnp.matmul(A, x), None
      init_x = jnp.zeros(A.shape[-1:])
      last_x, _ = for_loop.scan(step, init_x, jnp.arange(10))
      return jnp.sum(last_x)

    A = jnp.zeros((3, 3))
    # The second DUS was unnecessarily replicating A across time.
    # We check XLA because _scan_impl is "underneath" the jaxpr language.
    s = jax.jit(jax.grad(loss)).lower(A).as_text('hlo')
    self.assertLess(s.count("dynamic-update-slice("), 2)

  @_for_loop_impls
  def test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals(
      self, for_impl):

    def body(i, refs):
      a_ref, b_ref, c_ref = refs
      a = a_ref[i]
      b = b_ref[()]
      x = jnp.sin(a)
      b_ref[()] = jnp.sin(b * x)
      c_ref[i] = x * b

    def f(a, b):
      c = jnp.zeros_like(a)
      _, b, c = for_impl(5, body, (a, b, c))
      return b, c

    a = jnp.arange(5.) + 1.
    b = jnp.ones_like(a[0])
    # Tangents from linearize and from jvp of the same function must agree.
    _, f_lin = jax.linearize(f, a, b)
    expected_tangents = f_lin(a, b)
    _, actual_tangents = jax.jvp(f, (a, b), (a, b))
    np.testing.assert_allclose(actual_tangents[0], expected_tangents[0],
                               rtol=1e-6, atol=1e-6)
    np.testing.assert_allclose(actual_tangents[1], expected_tangents[1],
                               rtol=1e-6, atol=1e-6)

    def body2(_, refs):
      # Here we use `i_ref` as a loop counter
      a_ref, b_ref, c_ref, i_ref = refs
      i = i_ref[()]
      a = a_ref[i]
      b = b_ref[()]
      x = jnp.sin(a)
      b_ref[()] = jnp.sin(b * x)
      c_ref[i] = x * b
      i_ref[()] = i + 1

    def g(a, b):
      c = jnp.zeros_like(a)
      _, b, c, _ = for_impl(5, body2, (a, b, c, 0))
      return b, c

    a = jnp.arange(5.) + 1.
    b = jnp.ones_like(a[0])
    # Linearize `g` (not `f`) so that the body whose counter lives in a ref is
    # the one being linearized, then compare with jvp of the same function.
    _, g_lin = jax.linearize(g, a, b)
    expected_tangents = g_lin(a, b)
    _, actual_tangents = jax.jvp(g, (a, b), (a, b))
    np.testing.assert_allclose(actual_tangents[0], expected_tangents[0],
                               rtol=1e-6, atol=1e-6)
    np.testing.assert_allclose(actual_tangents[1], expected_tangents[1],
                               rtol=1e-6, atol=1e-6)

  @jtu.sample_product(
      [dict(for_body_name=for_body_name, f=for_body, ref=ref,
            body_shapes=body_shapes, n=nsteps)
       for for_body_name, for_body, ref, body_shapes, nsteps
       in _GRAD_BODY_CASES],
      [dict(for_impl=for_impl, impl_name=impl_name)
       for for_impl, impl_name in FOR_LOOP_IMPLS],
  )
  @jtu.skip_on_devices("gpu", "cpu")  # TODO(mattjj,sharadmv, dougalm): timeouts?
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name,
                    impl_name):
    for_ = for_impl
    rng = self.rng()
    args = [rng.randn(*s) for s in body_shapes]
    tol = {np.float64: 1e-12, np.float32: 1e-4}
    # Every loss is the sum of the second output.
    ans = jax.grad(lambda args: for_(n, f, args)[1].sum())(args)
    ans_discharged = jax.grad(
        lambda args: for_reference(n, f, args)[1].sum())(args)
    expected = jax.grad(lambda args: ref(*args)[1].sum())(args)
    self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol,
                        atol=tol)
    self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
    jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2,
                    rtol=7e-3, atol=1e-2)

  @jtu.skip_on_devices("gpu", "cpu")  # TODO(mattjj,sharadmv, dougalm): timeouts?
  @jax.legacy_prng_key('allow')
  def test_grad_of_triple_nested_for_loop(self):

    func = lambda x: jnp.sin(x) + 1.

    @jax.jit
    def f(x):
      out = jnp.zeros_like(x)
      def body(i, j, k, refs):
        x_ref, out_ref = refs
        y = func(x_ref[i, j, k])
        out_ref[i, j, k] += y
      # Passing the full shape as the step count gives the body three indices.
      return for_loop.for_loop(x.shape, body, (x, out))[1].sum()

    x = random.normal(random.PRNGKey(0), (5, 4, 3))
    ref = lambda x: jax.vmap(jax.vmap(jax.vmap(func)))(x).sum()
    self.assertAllClose(f(x), ref(x))
    jtu.check_grads(f, (x,), order=2, atol=0.1, rtol=0.1)
# When executed as a script, run the tests through absltest using the loader
# provided by JAX's test utilities.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())