forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdebugging_primitives_test.py
872 lines (755 loc) · 26 KB
/
debugging_primitives_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import collections
import functools
import io
import textwrap
import unittest
from unittest import mock
from typing import Callable, Generator
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.experimental import sharding
from jax._src import ad_checkpoint
from jax._src import debugging
from jax._src import dispatch
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
import jax.numpy as jnp
import numpy as np
# Module-level setup executed once at import time: let absl-driven flag
# parsing configure `config` before any test runs.
config.parse_flags_with_absl()
# Short alias for the primitive under test; every test below calls this name.
debug_print = debugging.debug_print
@contextlib.contextmanager
def capture_stdout() -> Generator[Callable[[], str], None, None]:
  """Redirects ``sys.stdout`` into an in-memory buffer for the ``with`` block.

  Yields:
    A zero-argument callable that returns everything written to stdout while
    the redirect was active. It keeps working after the ``with`` block exits,
    so assertions may be placed outside the block.
  """
  sink = io.StringIO()
  with contextlib.redirect_stdout(sink):
    yield sink.getvalue
def _format_multiline(text):
return textwrap.dedent(text).lstrip()
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
# This will control the CPU devices. On TPU we always have 2 devices
prev_xla_flags = jtu.set_host_platform_device_count(2)
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
prev_xla_flags()
# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
# version is >= 0.3.15
# Backends passed to `jtu.skip_on_devices(...)` by the tests below: TPU is
# excluded whenever the installed jaxlib predates 0.3.15.
disabled_backends = ["tpu"] if jaxlib.version < (0, 3, 15) else []
class DebugPrintTest(jtu.JaxTestCase):
  """Basic `debug_print` behavior: eager calls, `jax.jit`, ordered prints,
  pytree arguments and buffer donation.

  Every test captures stdout, calls `jax.effects_barrier()` while the capture
  is still active (so pending print callbacks land in the buffer), and then
  compares the captured text.
  """

  def tearDown(self):
    super().tearDown()
    # Reset `dispatch.runtime_tokens` so state left by one test (presumably
    # from ordered effects) does not leak into the next one.
    dispatch.runtime_tokens.clear()

  def _skip_if_donation_unsupported(self):
    """Skips the calling test unless the default backend is GPU or TPU."""
    if jax.default_backend() not in {"gpu", "tpu"}:
      raise unittest.SkipTest("Donate argnums not supported.")

  @jtu.skip_on_devices(*disabled_backends)
  def test_simple_debug_print_works_in_eager_mode(self):
    def f(x):
      debug_print('x: {}', x)
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_debug_print_works_with_named_format_strings(self):
    def f(x):
      debug_print('x: {x}', x=x)
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_multiple_debug_prints_should_print_multiple_values(self):
    def f(x):
      debug_print('x: {x}', x=x)
      debug_print('y: {y}', y=x + 1)
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\ny: 3\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_stage_out_debug_print(self):
    @jax.jit
    def f(x):
      debug_print('x: {x}', x=x)
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_stage_out_debug_print_with_donate_argnums(self):
    self._skip_if_donation_unsupported()
    def f(x, y):
      debug_print('x: {x}', x=x)
      return x + y
    f = jax.jit(f, donate_argnums=0)
    with capture_stdout() as output:
      f(2, 3)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_stage_out_ordered_print(self):
    @jax.jit
    def f(x):
      debug_print('x: {x}', x=x, ordered=True)
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_stage_out_ordered_print_with_donate_argnums(self):
    self._skip_if_donation_unsupported()
    def f(x, y):
      debug_print('x: {x}', x=x, ordered=True)
      return x + y
    f = jax.jit(f, donate_argnums=0)
    with capture_stdout() as output:
      f(2, 3)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_stage_out_prints_with_donate_argnums(self):
    self._skip_if_donation_unsupported()
    # Mixes an ordered and an unordered print in the same donated computation.
    def f(x, y):
      debug_print('x: {x}', x=x, ordered=True)
      debug_print('x: {x}', x=x)
      return x + y
    f = jax.jit(f, donate_argnums=0)
    with capture_stdout() as output:
      f(2, 3)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\nx: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_double_stage_out_ordered_print(self):
    # Nested `jax.jit`: the ordered print must survive two levels of staging.
    @jax.jit
    @jax.jit
    def f(x):
      debug_print('x: {x}', x=x, ordered=True)
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_stage_out_ordered_print_with_pytree(self):
    @jax.jit
    def f(x):
      struct = dict(foo=x)
      debug_print('x: {}', struct, ordered=True)
    with capture_stdout() as output:
      f(np.array(2, np.int32))
      jax.effects_barrier()
    # The pytree is printed with the same formatting as the equivalent
    # host-side dict of NumPy arrays.
    self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")
class DebugPrintTransformationTest(jtu.JaxTestCase):
  """`debug_print` under transformations: `vmap`, `jvp`, `vjp`/`grad`,
  custom derivative rules, `linear_transpose` and rematerialization."""

  def test_debug_print_batching(self):
    @jax.vmap
    def f(x):
      debug_print('hello: {}', x)
    with capture_stdout() as output:
      f(jnp.arange(2))
      jax.effects_barrier()
    # One print per batch element.
    self.assertEqual(output(), "hello: 0\nhello: 1\n")

  def test_debug_print_batching_with_diff_axes(self):
    @functools.partial(jax.vmap, in_axes=(0, 1))
    def f(x, y):
      debug_print('hello: {} {}', x, y)
    with capture_stdout() as output:
      f(jnp.arange(2), jnp.arange(2)[None])
      jax.effects_barrier()
    self.assertEqual(output(), "hello: 0 [0]\nhello: 1 [1]\n")

  def test_debug_print_with_nested_vmap(self):
    def f(x):
      debug_print('hello: {}', x)
    # Call with
    # [[0, 1],
    #  [2, 3],
    #  [4, 5]]
    with capture_stdout() as output:
      # Should print over 0-axis then 1-axis
      jax.vmap(jax.vmap(f))(jnp.arange(6).reshape((3, 2)))
      jax.effects_barrier()
    self.assertEqual(
        output(),
        "hello: 0\nhello: 2\nhello: 4\nhello: 1\nhello: 3\nhello: 5\n")
    with capture_stdout() as output:
      # Should print over 1-axis then 0-axis
      jax.vmap(jax.vmap(f, in_axes=0), in_axes=1)(jnp.arange(6).reshape((3, 2)))
      jax.effects_barrier()
    self.assertEqual(
        output(),
        "hello: 0\nhello: 1\nhello: 2\nhello: 3\nhello: 4\nhello: 5\n")

  def test_debug_print_jvp_rule(self):
    def f(x):
      debug_print('x: {}', x)
    with capture_stdout() as output:
      jax.jvp(f, (1.,), (1.,))
      jax.effects_barrier()
    # Only the primal value is printed.
    self.assertEqual(output(), "x: 1.0\n")

  def test_debug_print_vjp_rule(self):
    def f(x):
      debug_print('x: {}', x)
    with capture_stdout() as output:
      jax.vjp(f, 1.)
      jax.effects_barrier()
    self.assertEqual(output(), "x: 1.0\n")

  def test_debug_print_in_custom_jvp(self):
    @jax.custom_jvp
    def print_tangent(x):
      return x

    @print_tangent.defjvp
    def _(primals, tangents):
      (x,), (t,) = primals, tangents
      debug_print("x_tangent: {}", t)
      return x, t

    def f(x):
      x = jnp.sin(x)
      x = print_tangent(x)
      return x

    with capture_stdout() as output:
      x = jnp.array(1., jnp.float32)
      jax.jvp(f, (x,), (x,))
      jax.effects_barrier()
    # The tangent reaching `print_tangent` is d/dx sin(x) * 1 = cos(1).
    expected = jnp.cos(jnp.array(1., jnp.float32))
    self.assertEqual(output(), f"x_tangent: {expected}\n")

  @unittest.skip("doesn't work yet!")  # TODO(mattjj,sharadmv)
  def test_debug_print_in_custom_jvp_linearize(self):
    @jax.custom_jvp
    def print_tangent(x):
      return x

    @print_tangent.defjvp
    def _(primals, tangents):
      (x,), (t,) = primals, tangents
      debug_print("x_tangent: {}", t)
      return x, t

    def f(x):
      x = jnp.sin(x)
      x = print_tangent(x)
      return x

    with capture_stdout() as output:
      x = jnp.array(1., jnp.float32)
      _, f_lin = jax.linearize(f, x)
      jax.effects_barrier()
    # Linearizing alone should not print the tangent ...
    self.assertEqual(output(), "")
    with capture_stdout() as output:
      _ = f_lin(x)
      jax.effects_barrier()
    # ... only evaluating the linearized function should.
    expected = jnp.cos(jnp.array(1., jnp.float32))
    self.assertEqual(output(), f"x_tangent: {expected}\n")

  def test_debug_print_grad_with_custom_vjp_rule(self):
    @jax.custom_vjp
    def print_grad(x):
      return x

    def print_grad_fwd(x):
      return x, None

    def print_grad_bwd(_, x_grad):
      debug_print("x_grad: {}", x_grad)
      return (x_grad,)

    print_grad.defvjp(print_grad_fwd, print_grad_bwd)

    def f(x):
      debug_print("x: {}", x)
      x = print_grad(x)
      return jnp.sin(x)

    with capture_stdout() as output:
      jax.grad(f)(jnp.array(1., jnp.float32))
      jax.effects_barrier()
    # Forward print first, then the cotangent cos(1) from the backward pass.
    expected = jnp.cos(jnp.array(1., jnp.float32))
    self.assertEqual(output(), f"x: 1.0\nx_grad: {expected}\n")

  def test_debug_print_transpose_rule(self):
    def f(x):
      debug_print('should never be called: {}', x)
      return x
    with capture_stdout() as output:
      jax.linear_transpose(f, 1.)(1.)
      jax.effects_barrier()
    # `debug_print` should be dropped by `partial_eval` because of no
    # output data-dependence.
    self.assertEqual(output(), "")

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
      for ordered in [False, True]))
  def test_remat_of_debug_print(self, ordered):
    def f_(x):
      y = ad_checkpoint.checkpoint_name(x + 1., "y")
      z = ad_checkpoint.checkpoint_name(y * 2., "z")
      debug_print('y: {}, z: {}', y, z, ordered=ordered)
      return ad_checkpoint.checkpoint_name(jnp.exp(z), "w")

    # Policy that saves everything so the debug callback will be saved
    f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.everything_saveable)
    with capture_stdout() as output:
      jax.grad(f)(2.)
      jax.effects_barrier()
    # We expect the print to happen once since it gets saved and isn't
    # rematerialized.
    self.assertEqual(output(), "y: 3.0, z: 6.0\n")

    # Policy that saves nothing so everything gets rematerialized, including the
    # debug callback
    f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.nothing_saveable)
    with capture_stdout() as output:
      jax.grad(f)(2.)
      jax.effects_barrier()
    # We expect the print to happen twice since it is rematerialized.
    self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)

    # Policy that does not save `z` so we will need to rematerialize the print
    f = ad_checkpoint.checkpoint(
        f_, policy=ad_checkpoint.save_any_names_but_these("z"))
    with capture_stdout() as output:
      jax.grad(f)(2.)
      jax.effects_barrier()
    # We expect the print to happen twice since it is rematerialized.
    self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)

    def save_everything_but_these_names(*names_not_to_save):
      # Remat policy: a `name_p` equation is saveable unless its name is
      # listed; every other primitive is saveable.
      names_not_to_save = frozenset(names_not_to_save)
      def policy(prim, *_, **params):
        if prim is ad_checkpoint.name_p:
          return params['name'] not in names_not_to_save
        return True  # Save everything else
      return policy

    # Policy that saves everything but `y`
    f = ad_checkpoint.checkpoint(
        f_, policy=save_everything_but_these_names("y"))
    with capture_stdout() as output:
      jax.grad(f)(2.)
      jax.effects_barrier()
    # We expect the print to happen once because `y` is not rematerialized and
    # we won't do extra materialization.
    self.assertEqual(output(), "y: 3.0, z: 6.0\n")

    # Policy that saves everything but `y` and `z`
    f = ad_checkpoint.checkpoint(
        f_, policy=save_everything_but_these_names("y", "z"))
    with capture_stdout() as output:
      jax.grad(f)(2.)
      jax.effects_barrier()
    # We expect the print to happen twice because both `y` and `z` have been
    # rematerialized and we don't have to do any extra rematerialization to
    # print.
    self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debug_print_in_staged_out_custom_jvp(self):
    @jax.jit
    def f(x):
      @jax.custom_jvp
      def g(x):
        debug_print("hello: {x}", x=x)
        return x
      def g_jvp(primals, tangents):
        (x,), (t,) = primals, tangents
        debug_print("goodbye: {x} {t}", x=x, t=t)
        return x, t
      g.defjvp(g_jvp)
      return g(x)

    # Plain call: only the primal rule prints.
    with capture_stdout() as output:
      f(2.)
      jax.effects_barrier()
    self.assertEqual(output(), "hello: 2.0\n")

    # Under `jvp`: only the custom JVP rule prints.
    with capture_stdout() as output:
      jax.jvp(f, (2.,), (3.,))
      jax.effects_barrier()
    self.assertEqual(output(), "goodbye: 2.0 3.0\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_debug_print_in_staged_out_custom_vjp(self):
    @jax.jit
    def f(x):
      @jax.custom_vjp
      def g(x):
        debug_print("hello: {x}", x=x)
        return x
      def g_fwd(x):
        debug_print("hello fwd: {x}", x=x)
        return x, x
      def g_bwd(x, g):
        debug_print("hello bwd: {x} {g}", x=x, g=g)
        return (g,)
      g.defvjp(fwd=g_fwd, bwd=g_bwd)
      return g(x)

    # Plain call: only the primal function prints.
    with capture_stdout() as output:
      f(2.)
      jax.effects_barrier()
    self.assertEqual(output(), "hello: 2.0\n")

    # `vjp` runs the forward rule only ...
    with capture_stdout() as output:
      _, f_vjp = jax.vjp(f, 2.)
      jax.effects_barrier()
    self.assertEqual(output(), "hello fwd: 2.0\n")

    # ... and the backward rule runs when the VJP function is applied.
    with capture_stdout() as output:
      f_vjp(3.0)
      jax.effects_barrier()
    self.assertEqual(output(), "hello bwd: 2.0 3.0\n")
class DebugPrintControlFlowTest(jtu.JaxTestCase):
  """`debug_print` inside `lax` control flow.

  Covers `lax.scan`, `lax.fori_loop`, `lax.while_loop` (body and cond),
  `lax.cond` and `lax.switch`. Each test is parameterized over `ordered`
  (`False` -> test-name suffix "", `True` -> suffix "_ordered").
  """

  def _assertLinesEqual(self, text1, text2):
    # Order-insensitive comparison: both texts must consist of the same lines
    # with the same multiplicities (one `Counter` per text).
    def _count(lines):
      return collections.Counter(lines)
    self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n")))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
      for ordered in [False, True]))
  @jtu.skip_on_devices(*disabled_backends)
  def test_can_print_inside_scan(self, ordered):
    def f(xs):
      def _body(carry, x):
        debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered)
        return carry + 1, x + 1
      return lax.scan(_body, 2, xs)
    with capture_stdout() as output:
      f(jnp.arange(2))
      jax.effects_barrier()
    # The carry starts at 2 and is incremented once per scanned element.
    self.assertEqual(
        output(),
        _format_multiline("""
carry: 2, x: 0
carry: 3, x: 1
"""))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
      for ordered in [False, True]))
  @jtu.skip_on_devices(*disabled_backends)
  def test_can_print_inside_for_loop(self, ordered):
    def f(x):
      def _body(i, x):
        debug_print("i: {i}", i=i, ordered=ordered)
        debug_print("x: {x}", x=x, ordered=ordered)
        return x + 1
      return lax.fori_loop(0, 5, _body, x)
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    expected = _format_multiline("""
i: 0
x: 2
i: 1
x: 3
i: 2
x: 4
i: 3
x: 5
i: 4
x: 6
""")
    # Ordered prints must appear in exactly this order; for unordered prints
    # only the multiset of lines is checked.
    if ordered:
      self.assertEqual(output(), expected)
    else:
      self._assertLinesEqual(output(), expected)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
      for ordered in [False, True]))
  @jtu.skip_on_devices(*disabled_backends)
  def test_can_print_inside_while_loop_body(self, ordered):
    def f(x):
      def _cond(x):
        return x < 10
      def _body(x):
        debug_print("x: {x}", x=x, ordered=ordered)
        return x + 1
      return lax.while_loop(_cond, _body, x)
    with capture_stdout() as output:
      f(5)
      jax.effects_barrier()
    # The body runs for x = 5..9; it never runs once x reaches 10.
    self.assertEqual(output(), _format_multiline("""
x: 5
x: 6
x: 7
x: 8
x: 9
"""))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
      for ordered in [False, True]))
  @jtu.skip_on_devices(*disabled_backends)
  def test_can_print_inside_while_loop_cond(self, ordered):
    def f(x):
      def _cond(x):
        debug_print("x: {x}", x=x, ordered=ordered)
        return x < 10
      def _body(x):
        return x + 1
      return lax.while_loop(_cond, _body, x)
    with capture_stdout() as output:
      f(5)
      jax.effects_barrier()
    # The cond is evaluated one more time than the body: x = 5..10.
    self.assertEqual(output(), _format_multiline("""
x: 5
x: 6
x: 7
x: 8
x: 9
x: 10
"""))
    with capture_stdout() as output:
      f(10)
      jax.effects_barrier()
    # Should run the cond once
    self.assertEqual(output(), _format_multiline("""
x: 10
"""))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
      for ordered in [False, True]))
  @jtu.skip_on_devices(*disabled_backends)
  def test_can_print_inside_cond(self, ordered):
    def f(x):
      def true_fun(x):
        debug_print("true: {}", x, ordered=ordered)
        return x
      def false_fun(x):
        debug_print("false: {}", x, ordered=ordered)
        return x
      return lax.cond(x < 5, true_fun, false_fun, x)
    # Only the branch that is taken may print.
    with capture_stdout() as output:
      f(5)
      jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
false: 5
"""))
    with capture_stdout() as output:
      f(4)
      jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
true: 4
"""))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name="_ordered" if ordered else "", ordered=ordered)
      for ordered in [False, True]))
  @jtu.skip_on_devices(*disabled_backends)
  def test_can_print_inside_switch(self, ordered):
    def f(x):
      def b1(x):
        debug_print("b1: {}", x, ordered=ordered)
        return x
      def b2(x):
        debug_print("b2: {}", x, ordered=ordered)
        return x
      def b3(x):
        debug_print("b3: {}", x, ordered=ordered)
        return x
      # `x` is both the branch index and the printed operand.
      return lax.switch(x, (b1, b2, b3), x)
    with capture_stdout() as output:
      f(0)
      jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
b1: 0
"""))
    with capture_stdout() as output:
      f(1)
      jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
b2: 1
"""))
    with capture_stdout() as output:
      f(2)
      jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
b3: 2
"""))
class DebugPrintParallelTest(jtu.JaxTestCase):
  """`debug_print` under multi-device APIs (`pmap`, `pjit`, `xmap`), plus
  format-string validation errors raised at trace time."""

  def _assertLinesEqual(self, text1, text2):
    # Order-insensitive comparison: prints from different devices may
    # interleave, so compare the multiset of lines instead of the raw text.
    def _count(lines):
      return collections.Counter(lines)
    self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n")))

  @jtu.skip_on_devices(*disabled_backends)
  def test_ordered_print_not_supported_in_pmap(self):
    @jax.pmap
    def f(x):
      debug_print("{}", x, ordered=True)
    with self.assertRaisesRegex(
        ValueError, "Ordered effects not supported in `pmap`."):
      f(jnp.arange(jax.local_device_count()))

  @jtu.skip_on_devices(*disabled_backends)
  def test_unordered_print_works_in_pmap(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")

    @jax.pmap
    def f(x):
      debug_print("hello: {}", x, ordered=False)
    with capture_stdout() as output:
      f(jnp.arange(jax.local_device_count()))
      jax.effects_barrier()
    lines = [f"hello: {i}\n" for i in range(jax.local_device_count())]
    self._assertLinesEqual(output(), "".join(lines))

    @jax.pmap
    def f2(x):
      debug_print('hello: {}', x)
      debug_print('hello: {}', x + 2)
    with capture_stdout() as output:
      f2(jnp.arange(2))
      jax.effects_barrier()
    self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_unordered_print_with_pjit(self):
    if jax.default_backend() in {"cpu", "gpu"} and jaxlib.version < (0, 3, 16):
      raise unittest.SkipTest("`pjit` of callback not supported.")

    def f(x):
      debug_print("{}", x, ordered=False)
      return x

    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
    if config.jax_array:
      spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))
      out_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())
    else:
      spec = pjit.PartitionSpec('dev')
      out_spec = pjit.PartitionSpec()

    f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
    with mesh:
      with capture_stdout() as output:
        f(np.arange(8, dtype=jnp.int32))
        jax.effects_barrier()
      self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n")

    def f2(x):
      y = x.dot(x)
      debug_print("{}", y, ordered=False)
      return y

    f2 = pjit.pjit(f2, in_axis_resources=spec, out_axis_resources=out_spec)
    # Reuse `mesh`: under `jax_array` the shardings above are bound to it.
    with mesh:
      with capture_stdout() as output:
        f2(np.arange(8, dtype=jnp.int32))
        jax.effects_barrier()
      # sum(i * i for i in range(8)) == 140
      self.assertEqual(output(), "140\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_unordered_print_of_pjit_of_while(self):
    if (jax.default_backend() in {"cpu", "gpu"}
        and jaxlib.xla_extension_version < 81):
      raise unittest.SkipTest("`pjit` of callback not supported.")

    def f(x):
      def cond(carry):
        i, *_ = carry
        return i < 5
      def body(carry):
        i, x = carry
        debug_print("{}", x, ordered=False)
        x = x + 1
        return (i + 1, x)
      return lax.while_loop(cond, body, (0, x))[1]

    mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
    if config.jax_array:
      spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))
    else:
      spec = pjit.PartitionSpec('dev')
    f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
    with mesh:
      with capture_stdout() as output:
        f(np.arange(8, dtype=jnp.int32))
        jax.effects_barrier()
      # NumPy right-aligns elements to the widest one, so once a two-digit
      # value appears the single-digit entries are padded with a space.
      self.assertEqual(output(),
                       "[0 1 2 3 4 5 6 7]\n"
                       "[1 2 3 4 5 6 7 8]\n"
                       "[2 3 4 5 6 7 8 9]\n"
                       "[ 3  4  5  6  7  8  9 10]\n"
                       "[ 4  5  6  7  8  9 10 11]\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_unordered_print_of_pjit_of_xmap(self):
    # TODO(https://github.com/google/jax/issues/12016): Make xmap work properly
    # with Arrays of different
    # sharding.
    if config.jax_array:
      raise unittest.SkipTest('Does not work with Array.')
    if (jax.default_backend() in {"cpu", "gpu"}
        and jaxlib.xla_extension_version < 81):
      raise unittest.SkipTest("`pjit` of callback not supported.")

    def f(x):
      def foo(x):
        idx = lax.axis_index('foo')
        debug_print("{idx}: {x}", idx=idx, x=x)
        return jnp.mean(x, axis=['foo'])
      out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)
      debug_print("Out: {}", out)
      return out

    f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('dev'),
                  out_axis_resources=pjit.PartitionSpec())
    with maps.Mesh(np.array(jax.devices()), ['dev']):
      with capture_stdout() as output:
        f(jnp.arange(8, dtype=jnp.int32) * 2)
        # Barrier inside the capture so every callback lands in the buffer.
        jax.effects_barrier()
      # One line per mapped element, then the mean of 0, 2, ..., 14 (= 7.0).
      lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",
               "7: 14", "Out: 7.0", ""]
      self._assertLinesEqual(output(), "\n".join(lines))

  @jtu.skip_on_devices(*disabled_backends)
  def test_unordered_print_with_xmap(self):
    def f(x):
      debug_print("{}", x, ordered=False)
    f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
                  axis_resources={'a': 'dev'})
    with maps.Mesh(np.array(jax.devices()), ['dev']):
      with capture_stdout() as output:
        f(np.arange(40))
        jax.effects_barrier()
      lines = [f"{i}\n" for i in range(40)]
      self._assertLinesEqual(output(), "".join(lines))

  @jtu.skip_on_devices(*disabled_backends)
  def test_unordered_print_works_in_pmap_of_while(self):
    if jax.device_count() < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")

    @jax.pmap
    def f(x):
      def cond(x):
        return x < 3
      def body(x):
        debug_print("hello: {}", x, ordered=False)
        return x + 1
      return lax.while_loop(cond, body, x)

    with capture_stdout() as output:
      f(jnp.arange(2))
      jax.effects_barrier()
    # Device 0 counts 0 -> 2, device 1 counts 1 -> 2.
    self._assertLinesEqual(
        output(), "hello: 0\nhello: 1\nhello: 2\n"
        "hello: 1\nhello: 2\n")

  @jtu.skip_on_devices(*disabled_backends)
  def test_incorrectly_formatted_string(self):
    # Named placeholder with only a positional argument -> KeyError.
    @jax.jit
    def f(x):
      debug_print("hello: {x}", x)
      return x
    with self.assertRaises(KeyError):
      f(jnp.arange(2))
      jax.effects_barrier()

    # Positional placeholder with only a keyword argument -> IndexError.
    @jax.jit
    def g(x):
      debug_print("hello: {}", x=x)
      return x
    with self.assertRaises(IndexError):
      g(jnp.arange(2))
      jax.effects_barrier()

  @jtu.skip_on_devices(*disabled_backends)
  def test_format_string_errors_with_unused_args(self):
    @jax.jit
    def f(x):
      debug_print("hello: {x}", x=x, y=x)
      return x
    with self.assertRaisesRegex(ValueError, "Unused keyword arguments"):
      f(jnp.arange(2))
      jax.effects_barrier()

    @jax.jit
    def g(x):
      debug_print("hello", x)
      return x
    with self.assertRaisesRegex(ValueError, "Unused positional arguments"):
      g(jnp.arange(2))
      jax.effects_barrier()

  @jtu.skip_on_devices(*disabled_backends)
  def test_accidental_fstring(self):
    # The f-string is formatted eagerly with a tracer, which `debug_print`
    # is expected to detect and reject.
    @jax.jit
    def f(x):
      debug_print(f"hello: {x}", x=x)
      return x
    with self.assertRaisesRegex(ValueError, "You may be passing an f-string"):
      f(jnp.arange(2))
      jax.effects_barrier()
if __name__ == '__main__':
  # Direct execution: run every test in this module through JAX's loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)