forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pmap_test.py
3282 lines (2698 loc) · 118 KB
/
pmap_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import itertools as it
import gc
import math
import os
from random import shuffle
import re
from typing import Union, cast
import unittest
from unittest import SkipTest
import weakref
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax import lax
from jax import random
from jax import tree_util
from jax.ad_checkpoint import checkpoint as new_checkpoint
import jax.numpy as jnp
from jax._src import api as src_api
from jax._src import array
from jax._src import core
from jax._src import config
from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.internal_test_util import lax_test_util
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lax import parallel
from jax._src.lib import xla_extension
from jax._src.util import safe_map, safe_zip
# Parse JAX config flags through absl before any test runs.
config.parse_flags_with_absl()

# Caller's original XLA_FLAGS value; assigned in setUpModule and read back in
# tearDownModule.
prev_xla_flags = None

# Groups of mutually broadcast-compatible shapes.
compatible_shapes = [
    [(3,)],
    [(3, 4), (3, 1), (1, 4)],
    [(2, 3, 4), (2, 1, 4)],
]
def all_bdims(*shapes, pmap):
  """Return a generator over batch-dim assignments for `shapes`.

  For every shape, the candidates are None (unbatched) or any insertion
  position 0..len(shape). Combinations in which every entry is None are
  skipped. The `pmap` keyword is accepted but not used.
  """
  options_per_shape = [[None, *range(len(shape) + 1)] for shape in shapes]
  return (combo for combo in it.product(*options_per_shape)
          if any(dim is not None for dim in combo))
def out_bdims(shape, pmap):
  """Return a generator over the non-None batch dims for a single `shape`.

  `all_bdims` is given one shape here, so each candidate is a 1-tuple. The
  generator yields that tuple's only entry.
  """
  candidates = all_bdims(shape, pmap=pmap)
  return (first for first, *_ in candidates if first is not None)
def add_bdim(bdim_size, bdim, shape):
  """Return `shape` as a tuple with `bdim_size` inserted at index `bdim`.

  If `bdim` is None, the shape comes back unchanged (as a tuple). Positions
  follow `list.insert` semantics, so negative indices count from the end and
  out-of-range indices clamp to either end.
  """
  dims = list(shape)
  if bdim is None:
    return tuple(dims)
  return (*dims[:bdim], bdim_size, *dims[bdim:])
def slicer(x, bdim):
  """Return a function of an index `i` that selects slice `i` of `x`.

  The slice is taken along `bdim` with that axis dropped. If `bdim` is None,
  the returned function ignores its argument and always returns `x`.
  """
  if bdim is not None:
    return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
  return lambda _: x
def args_slicer(args, bdims):
  """Build a function that maps an index `i` to the list of per-arg slices.

  Each argument is paired with its batch dim through `slicer`. Arguments whose
  bdim is None are passed through unchanged.
  """
  per_arg = safe_map(slicer, args, bdims)

  def slice_all(i):
    return [take(i) for take in per_arg]

  return slice_all
# Run all tests with 8 CPU devices.
def setUpModule():
  """Request 8 host CPU devices unless XLA_FLAGS already sets a device count.

  The caller's XLA_FLAGS value is saved in the module-global `prev_xla_flags`
  so that it can be restored afterwards.
  """
  global prev_xla_flags
  prev_xla_flags = os.getenv("XLA_FLAGS")
  existing_flags = prev_xla_flags or ""
  # Leave a user-specified device count (and any other XLA flags) alone.
  wants_default_count = (
      "xla_force_host_platform_device_count" not in existing_flags)
  if wants_default_count:
    os.environ["XLA_FLAGS"] = (
        existing_flags + " --xla_force_host_platform_device_count=8")
  # Drop cached backends so a fresh CPU backend picks up the env var.
  xla_bridge.get_backend.cache_clear()
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  """Restore XLA_FLAGS to its pre-module value and clear cached backends."""
  if prev_xla_flags is not None:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  else:
    # XLA_FLAGS was unset before the module ran, so remove it again.
    del os.environ["XLA_FLAGS"]
  xla_bridge.get_backend.cache_clear()
# `jtu.ignore_warning` pre-bound to specific warning-message patterns.
ignore_jit_of_pmap_warning = partial(jtu.ignore_warning,
                                     message=".*jit-of-pmap.*")
ignore_xmap_warning = partial(jtu.ignore_warning,
                              message=".*is an experimental.*")
def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
                                devices=None, sharded_dim_size=None):
  """Build a pmap-sharded jax.Array along with its host-side source data.

  Args:
    input_shape: global shape of the array to build.
    in_axes: axis that is sharded across devices.
    input_data: host data to shard; defaults to `arange` reshaped to
      `input_shape`.
    devices: devices for the PmapSharding; defaults to `jax.devices()`.
    sharded_dim_size: forwarded to `create_pmap_sharding_spec`.

  Returns:
    A `(array, input_data)` pair.
  """
  if input_data is None:
    num_elements = math.prod(input_shape)
    input_data = np.arange(num_elements).reshape(input_shape)
  spec = sharding_specs.create_pmap_sharding_spec(
      input_shape, in_axes, sharded_dim_size)
  device_list = jax.devices() if devices is None else devices
  sharding = jax.sharding.PmapSharding(np.array(device_list), spec)

  def shard_for(index):
    # Each device receives the slice of the host data at its index.
    return input_data[index]

  arr = array.make_array_from_callback(input_shape, sharding, shard_for)
  return arr, input_data
@jtu.pytest_mark_if_available('multiaccelerator')
@jtu.with_config(jax_legacy_prng_key="allow")
class PythonPmapTest(jtu.JaxTestCase):
@property
def pmap(self):
  """The `pmap` implementation exercised by this class (`jax._src.api`)."""
  impl = src_api.pmap
  return impl
def testDeviceBufferToArray(self):
  """jnp.array on one shard of a pmap output aliases or copies as requested."""
  num_devices = jax.device_count()
  sda = self.pmap(lambda x: x)(jnp.ones((num_devices, 2)))
  # Since https://github.com/google/jax/pull/10584 this indexes the pmap
  # output instead of reading the unsupported `sda.device_buffers`, which
  # also checks that fast slices of pmap results are set up correctly.
  last_shard = sda[-1]

  # copy=False: same device, same underlying buffer.
  aliased = jnp.array(last_shard, copy=False)
  self.assertArraysEqual(sda[-1], aliased)
  self.assertSetEqual(last_shard.devices(), aliased.devices())
  self.assertEqual(last_shard.unsafe_buffer_pointer(),
                   aliased.unsafe_buffer_pointer())

  # copy=True: same device, distinct buffer.
  copied = jnp.array(last_shard, copy=True)
  self.assertArraysEqual(sda[-1], copied)
  self.assertSetEqual(last_shard.devices(), copied.devices())
  self.assertNotEqual(last_shard.unsafe_buffer_pointer(),
                      copied.unsafe_buffer_pointer())
def _getMeshShape(self, device_mesh_shape):
  """Resolve `device_mesh_shape` against the number of available devices.

  A -1 entry is inferred as in `reshape`. Raises SkipTest if that reshape
  fails, or (when there is no -1) if the mesh size does not divide the device
  count. Otherwise the shape is returned.
  """
  device_count = jax.device_count()
  if -1 in device_mesh_shape:
    try:
      return np.arange(device_count).reshape(device_mesh_shape).shape
    except ValueError as err:
      msg = "device mesh shape {} not compatible with device count {}"
      raise SkipTest(msg.format(device_mesh_shape, device_count)) from err
  mesh_size = math.prod(device_mesh_shape)
  if device_count % mesh_size:
    msg = "device mesh size {} does not divide available device count {}"
    raise SkipTest(msg.format(mesh_size, device_count))
  return device_mesh_shape
def testBasic(self):
  """x - psum(x) over the mapped axis matches x minus its column sums."""
  f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
  num_devices = jax.device_count()
  x = np.arange(num_devices * 4, dtype=np.float32).reshape(num_devices, 4)
  column_sums = x.sum(axis=0)
  self.assertAllClose(f(x), x - column_sums, check_dtypes=False)
def testDefaultDeviceOrdering(self):
  """pmap's default device order matches jax.devices()."""
  # Users rely on the fact that the default order of jax.devices() matches
  # the default order of pmap for single-host jobs.
  expected_order = jax.devices()
  out_sharding = pmap(lambda x: x)(np.arange(jax.device_count())).sharding
  if not config.pmap_shmap_merge.value:
    self.assertListEqual(expected_order, out_sharding.devices.tolist())
  else:
    self.assertListEqual(expected_order, out_sharding._device_assignment)
def testLowerCompile(self):
  """Explicit lower().compile() gives the same result as calling f."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  expected = f(x)
  lowered = f.lower(x)
  compiled = lowered.compile()
  self.assertAllClose(compiled(x), expected)

  # Recorded signature: (positional args as a tuple of structures, kwargs),
  # here one positional argument and an empty kwargs dict.
  expected_tree = jax.tree_util.tree_flatten(((0,), {}))[1]
  expected_avals = ((core.ShapedArray(x.shape, x.dtype),), {})
  for stage in (lowered, compiled):
    self.assertFalse(stage._no_kwargs)
    self.assertEqual(stage.in_tree, expected_tree)
    self.assertEqual(stage.in_avals, expected_avals)
def testLowerCompileInTreeMismatch(self):
  """A compiled pmap rejects an argument with a different pytree structure."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  compiled = f.lower(x).compile()
  # Wrapping x in a list changes the input tree.
  with self.assertRaisesRegex(TypeError,
                              "function compiled for .*, called with .*"):
    compiled([x])
def testLowerCompileTrivial(self):
  """Explicit lower/compile works for an identity pmap."""
  f = self.pmap(lambda x: x, axis_name='i')
  x = np.arange(jax.device_count(), dtype=np.float32)
  expected = f(x)
  compiled = f.lower(x).compile()
  self.assertAllClose(compiled(x), expected)
def testLowerCompileTrivialInTreeMismatch(self):
  """A compiled identity pmap rejects a differently structured argument."""
  f = self.pmap(lambda x: x, axis_name='i')
  x = np.arange(jax.device_count(), dtype=np.float32)
  compiled = f.lower(x).compile()
  with self.assertRaisesRegex(TypeError,
                              "function compiled for .*, called with .*"):
    compiled([x])
def testLowerCompileArgTypeMismatch(self):
  """Calling a float32-compiled pmap with int32 input reports the argument."""
  # The lambda parameter must stay named `x`: the error message cites it.
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  base = np.arange(math.prod(shape), dtype=int).reshape(shape)
  as_f32 = base.astype(jnp.float32)
  as_i32 = base.astype(jnp.int32)
  compiled = f.lower(as_f32).compile()
  expected_msg = (
      r"Argument types differ .*"
      r"The mismatches are:\n"
      r"Argument 'x' compiled with.*float32.*and called with.*int32.*")
  with self.assertRaisesRegex(TypeError, expected_msg):
    compiled(as_i32)
def testLowerCompileMultiArg(self):
  """Explicit lower/compile works with two positional arguments."""
  f = self.pmap(lambda x, y: x - lax.pmean(y, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  expected = f(data, data)
  compiled = f.lower(data, data).compile()
  self.assertAllClose(compiled(data, data), expected)
def testLowerCompileTrivialMultiArg(self):
  """Explicit lower/compile works for a two-argument identity pmap."""
  f = self.pmap(lambda x, y: (x, y), axis_name='i')
  data = np.arange(jax.device_count(), dtype=np.float32)
  expected = f(data, data)
  compiled = f.lower(data, data).compile()
  self.assertAllClose(compiled(data, data), expected)
def testLowerAsText(self):
  """Lowered.as_text returns a str for the default and each named dialect."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  lowered = f.lower(x)
  self.assertIsInstance(lowered.as_text(), str)
  for dialect in ('hlo', 'mhlo', 'stablehlo'):
    self.assertIsInstance(lowered.as_text(dialect=dialect), str)
def testLowerCompilerIR(self):
  """Lowered.compiler_ir is available for the default and named dialects."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  lowered = f.lower(x)
  self.assertIsNotNone(lowered.compiler_ir())
  for dialect in ('hlo', 'mhlo', 'stablehlo'):
    self.assertIsNotNone(lowered.compiler_ir(dialect=dialect))
def testLowerCompileCompilerIR(self):
  """A compiled pmap exposes a runtime executable."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  compiled = f.lower(x).compile()
  self.assertIsNotNone(compiled.runtime_executable())
def testLowerCompileAsText(self):
  """Compiled.as_text returns either a str or None."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  compiled = f.lower(x).compile()
  self.assertIsInstance(compiled.as_text(), (str, type(None)))
def testLowerCostAnalysis(self):
  """Lowered.cost_analysis runs without raising."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  lowered = f.lower(x)
  lowered.cost_analysis()  # Must not raise.
def testLowerCompileCostAnalysis(self):
  """Compiled.cost_analysis runs without raising."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  compiled = f.lower(x).compile()
  compiled.cost_analysis()  # Must not raise.
def testLowerCompileMemoryAnalysis(self):
  """Compiled.memory_analysis runs without raising."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  compiled = f.lower(x).compile()
  compiled.memory_analysis()  # Must not raise.
def testLowerCompileExecutable(self):
  """Compiled.runtime_executable returns a non-None executable."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  executable = f.lower(x).compile().runtime_executable()
  self.assertIsNotNone(executable)
def test_jit_lower_compile_with_compiler_options(self):
  """compile() accepts a valid XLA compiler option."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  options = {"xla_embed_ir_in_executable": True}
  f.lower(x).compile(compiler_options=options)  # Must not crash.
def test_jit_lower_compile_with_compiler_options_invalid(self):
  """Unknown option names and ill-typed option values are rejected."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  lowered = f.lower(x)

  with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
                              "No such compile option: 'invalid_key'"):
    lowered.compile(compiler_options={"invalid_key": "invalid_value"})

  with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
                              "is not a valid bool value."):
    lowered.compile(
        compiler_options={"xla_embed_ir_in_executable": "invalid_value"})
def test_pmap_replicated_copy(self):
  """jnp.copy of an out_axes=None pmap result lives on a single device."""
  # https://github.com/google/jax/issues/17690
  inp = jnp.arange(jax.device_count())
  replicated = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(inp)
  copied = jnp.copy(replicated)
  self.assertIsInstance(copied.sharding, jax.sharding.SingleDeviceSharding)
  self.assertArraysEqual(copied, inp[0])
def test_jit_lower_compile_with_compiler_options_multiple(self):
  """Compiling one Lowered with different options yields distinct objects."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  lowered = f.lower(x)
  default = lowered.compile()
  embed_on = lowered.compile(
      compiler_options={"xla_embed_ir_in_executable": True})
  embed_off = lowered.compile(
      compiler_options={"xla_embed_ir_in_executable": False})

  # Ideally we would check that the executables differ only in honouring
  # their options; object identity is a heuristic proxy for that.
  for a, b in ((default, embed_on), (default, embed_off),
               (embed_on, embed_off)):
    self.assertTrue(a is not b)

  # Invalid options must still be rejected after some valid compiles.
  with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
                              "No such compile option: 'invalid_key'"):
    lowered.compile(compiler_options={"invalid_key": "invalid_value"})
def testLowerShapedArray(self):
  """Lowering from an abstract ShapedArray matches calling f directly."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  aval = core.ShapedArray(x.shape, x.dtype)
  from_aval = f.lower(aval).compile()(x)
  self.assertAllClose(from_aval, f(x))
def testLowerHasReplicaAttributes(self):
  """The lowered module records one replica per device and one partition."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  num_devices = jax.device_count()
  shape = (num_devices, 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  hlo = f.lower(x).as_text("stablehlo")
  for attribute in (f"mhlo.num_replicas = {num_devices}",
                    "mhlo.num_partitions = 1"):
    self.assertIn(attribute, hlo)
def testMean(self):
  """x - pmean(x) over the mapped axis matches x minus its column means."""
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  column_means = np.broadcast_to(np.mean(x, 0), x.shape)
  self.assertAllClose(f(x), x - column_means, check_dtypes=False)
def testGather(self):
  """all_gather gives every device a full copy of the stacked input."""
  f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
  num_devices = jax.device_count()
  shape = (num_devices, 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  expected = np.stack([x] * num_devices)
  self.assertAllClose(f(x), expected, check_dtypes=False)
def testGatherBool(self):
  """all_gather works on boolean inputs."""
  f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
  num_devices = jax.device_count()
  shape = (num_devices, 4)
  values = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  x = (values % 2).astype(np.bool_)
  expected = np.stack([x] * num_devices)
  self.assertAllClose(f(x), expected, check_dtypes=False)
def testGatherNegativeAxis(self):
  """all_gather with axis=-1 stacks along the last axis (the transpose)."""
  f = self.pmap(lambda x: lax.all_gather(x, 'i', axis=-1), axis_name='i')
  num_devices = jax.device_count()
  shape = (num_devices, 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  expected = np.stack([x.T] * num_devices)
  self.assertAllClose(f(x), expected, check_dtypes=False)
def testGatherTiled(self):
  """Tiled all_gather concatenates all shards into one flat row per device."""
  f = self.pmap(lambda x: lax.all_gather(x, 'i', tiled=True), axis_name='i')
  num_devices = jax.device_count()
  shape = (num_devices, 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  expected = np.stack([x] * num_devices).reshape(num_devices, -1)
  self.assertAllClose(f(x), expected, check_dtypes=False)
def testGatherTiledNegativeAxis(self):
  """Tiled all_gather along axis=-1 concatenates shards on the last axis."""
  def gather_last(x):
    return lax.all_gather(x, 'i', tiled=True, axis=-1)

  f = self.pmap(gather_last, axis_name='i')
  num_devices = jax.device_count()
  shape = (num_devices, 4, 3)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  # Row a of each device's result is [x[0, a], x[1, a], ...] concatenated.
  per_device = x.transpose(1, 0, 2).reshape(4, -1)
  expected = np.stack([per_device] * num_devices)
  self.assertAllClose(f(x), expected, check_dtypes=False)
@parameterized.named_parameters([
    ('Gather', lax.all_gather),
    ('ReduceScatter', lax.psum_scatter)
])
def testVmapOf(self, prim):
  """vmap of a pmapped collective matches applying it to each slice."""
  f = self.pmap(partial(prim, axis_name='i'), axis_name='i')
  n = jax.device_count()
  shape = (4, n, n)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  batched = vmap(f)(x)
  per_slice = jnp.stack([f(chunk) for chunk in x], axis=0)
  self.assertAllClose(batched, per_slice)
def testReduceScatter(self):
  """Untiled psum_scatter leaves element i of the column sums on device i."""
  f = self.pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i')
  n = jax.device_count()
  x = np.arange(n * n, dtype=np.float32).reshape(n, n)
  column_sums = np.sum(x, axis=0)
  for actual, want in zip(f(x), column_sums):
    self.assertAllClose(actual, want)
def testReduceScatterTiled(self):
  """Tiled psum_scatter gives each device one equal chunk of the sum."""
  f = self.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')
  n = jax.device_count()
  shape = (n, 4 * n)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  chunks = np.split(np.sum(x, axis=0), n)
  for i, actual in enumerate(f(x)):
    self.assertAllClose(actual, chunks[i])
def testReduceScatterReplicaGroupsTiled(self):
  """Tiled psum_scatter within even/odd replica groups reduces per group."""
  replicas = jax.device_count()
  if replicas % 2 != 0:
    raise SkipTest
  axis_index_groups = [list(range(0, replicas, 2)),
                       list(range(1, replicas, 2))]
  f = self.pmap(
      lambda x: lax.psum_scatter(
          x, 'i', axis_index_groups=axis_index_groups, tiled=True),
      axis_name='i')

  shape = (replicas, 4 * replicas)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  ans = f(x)

  group_sums = (np.sum(x[0::2, :], axis=0), np.sum(x[1::2, :], axis=0))
  # Each group's sum is scattered over replicas // 2 devices.
  scatter_len = len(group_sums[0]) * 2 // replicas
  for i, actual in enumerate(ans):
    start = (i // 2) * scatter_len
    expected = group_sums[i % 2]
    self.assertAllClose(actual, expected[start:start + scatter_len])
def testTrees(self):
  """Collectives applied to a dict of arrays match NumPy references leafwise."""
  def ptranspose(x, axis_name):
    return lax.all_to_all(x, axis_name, 0, 0)

  def protate(x, axis_name):
    n = lax.psum(1, axis_name)
    return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)])

  def pmapped(collective):
    return self.pmap(lambda x: collective(x, 'i'), 'i')

  def numpy_ref(reduction):
    return partial(tree_util.tree_map,
                   lambda x: np.broadcast_to(reduction(x, 0), x.shape))

  np_transpose = partial(tree_util.tree_map, np.transpose)
  np_rotate = partial(tree_util.tree_map,
                      lambda x: np.concatenate([x[-1:], x[:-1]]))

  n = jax.device_count()

  def block(lo):
    return np.arange(lo * n * n, (lo + 1) * n * n).reshape([n, n])

  x = {'a': block(1), 'b': block(2), 'c': block(4)}

  assert_allclose = partial(tree_util.tree_map,
                            partial(self.assertAllClose, check_dtypes=False))
  for collective, reduction in ((lax.pmax, np.max), (lax.pmin, np.min),
                                (lax.psum, np.sum), (lax.pmean, np.mean)):
    assert_allclose(pmapped(collective)(x), numpy_ref(reduction)(x))
  assert_allclose(pmapped(ptranspose)(x), np_transpose(x))
  assert_allclose(pmapped(protate)(x), np_rotate(x))
def testCollectivesWithTreesOfDifferentDtypes(self):
  """Reductions over a dict mixing float32 and int32 leaves match NumPy."""
  n = len(jax.devices())

  def block(lo, dtype):
    return np.arange(lo * n * n, (lo + 1) * n * n, dtype=dtype).reshape([n, n])

  x = {'a': block(1, np.float32), 'b': block(2, np.int32),
       'c': block(4, np.float32), 'd': block(6, np.int32)}

  def pmapped(collective):
    return self.pmap(lambda x: collective(x, 'i'), 'i')

  def numpy_ref(reduction):
    return partial(tree_util.tree_map,
                   lambda x: np.broadcast_to(reduction(x, 0), x.shape))

  assert_allclose = partial(tree_util.tree_map,
                            partial(self.assertAllClose, check_dtypes=False))
  for collective, reduction in ((lax.pmax, np.max), (lax.pmin, np.min),
                                (lax.psum, np.sum), (lax.pmean, np.mean)):
    assert_allclose(pmapped(collective)(x), numpy_ref(reduction)(x))
def testComplexPsum(self):
  """psum over the mapped axis works for complex64 inputs."""
  f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
  # 8 float32 columns reinterpret as 4 complex64 columns.
  shape = (jax.device_count(), 4 * 2)
  real_view = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  x = real_view.view(np.complex64)
  self.assertAllClose(f(x), x - np.sum(x, 0), check_dtypes=False)
@jtu.sample_product(
  split_axis=list(range(2)),
  concat_axis=list(range(2)),
  dtype=lax_test_util.all_dtypes,
)
def testAllToAll(self, split_axis, concat_axis, dtype):
  """all_to_all matches a moveaxis reference for each split/concat pair."""
  pmap_in_axis = 0
  shape = (jax.device_count(),) * 3
  x = jtu.rand_default(self.rng())(shape, dtype)

  @partial(self.pmap, axis_name='i')
  def f(x):
    return lax.all_to_all(x, 'i', split_axis, concat_axis)

  y = f(x)
  # Inside pmap the mapped axis is dropped, so a split axis at or after it
  # corresponds to one position later in the full array.
  full_split = split_axis + 1 if pmap_in_axis <= split_axis else split_axis
  ref = jnp.moveaxis(x, (pmap_in_axis, full_split), (concat_axis + 1, 0))
  self.assertAllClose(y, ref)
@parameterized.named_parameters(
    {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
     "split_axis": split_axis, "concat_axis": concat_axis}
    for split_axis, concat_axis in it.product(range(2), range(2)))
def testAllToAllSplitAxis(self, split_axis, concat_axis):
  """all_to_all over two nested pmap axes ('i', 'j') at once."""
  if jax.device_count() < 4:
    raise SkipTest("test requires at least four devices")
  pmap_in_axis = 0
  shape = (4, 4, 4)
  x = np.arange(math.prod(shape)).reshape(shape)

  @partial(self.pmap, axis_name='i')
  @partial(self.pmap, axis_name='j')
  def f(x):
    return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)

  # Split the leading axis of size 4 into a 2x2 grid for the nested pmaps.
  y = f(x.reshape((2, 2, *shape[1:]))).reshape(shape)
  full_split = split_axis + 1 if pmap_in_axis <= split_axis else split_axis
  ref = jnp.moveaxis(x, (pmap_in_axis, full_split), (concat_axis + 1, 0))
  self.assertAllClose(y, ref)
def testNestedPmapAxisSwap(self):
  """Inner in_axes=1 inside outer in_axes=0 transposes the last two dims."""
  # Regression test for https://github.com/google/jax/issues/5757
  if jax.device_count() < 8:
    raise SkipTest("test requires at least 8 devices")
  inner = jax.pmap(lambda x: x, in_axes=1, out_axes=0)
  f = jax.pmap(inner, in_axes=0, out_axes=0)
  A = jnp.ones((2, 4, 3))
  self.assertAllClose(A.transpose((0, 2, 1)), f(A))
def testNestedBasic(self):
  """psum over both axes of a nested pmap equals a full NumPy reduction."""
  f = self.pmap(
      self.pmap(lambda x: lax.psum(lax.psum(x, 'i'), 'j'), 'i'), 'j')

  def sum_and_broadcast(x, axis):
    total = np.sum(x, axis, keepdims=True)
    return np.repeat(total, x.shape[axis], axis)

  shape = (jax.device_count(), 1, 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
  self.assertAllClose(f(x), expected, check_dtypes=False)
def testMismatchedAxisSizes(self):
  """Mapped arguments with different leading sizes raise ValueError."""
  n = jax.device_count()
  f = self.pmap(lambda x, y: x + y)
  with self.assertRaisesRegex(
      ValueError,
      "pmap got inconsistent sizes for array axes to be mapped"):
    f(self.rng().randn(n), self.rng().randn(n - 1))
def testInAxesPyTreePrefixMismatchError(self):
  """A mismatched in_axes prefix is reported as `pmap in_axes[0][0]`."""
  x = jnp.array([3.14])
  f = self.pmap(lambda x, y: x, in_axes=((0, 0, 0), 0))
  expected = re.escape("pmap in_axes[0][0]")
  with self.assertRaisesRegex(ValueError, expected):
    f((x, x), x)
def testInAxesPyTreePrefixMismatchErrorKwargs(self):
  """Keyword arguments with a non-trivial in_axes spec raise ValueError."""
  x = jnp.array([3.14])
  f = self.pmap(lambda x, y: x, in_axes=((0, 0), 0))
  expected = re.escape("each argument passed by keyword is mapped")
  with self.assertRaisesRegex(ValueError, expected):
    f(x=(x, x), y=x)
def testOutAxesPyTreePrefixMismatchError(self):
  """A mismatched out_axes prefix is reported as `pmap out_axes[0]`."""
  x = jnp.array([3.14])
  f = jax.pmap(lambda x, y: ((x, x), x), out_axes=((0, 0, 0), 0))
  expected = re.escape("pmap out_axes[0]")
  with self.assertRaisesRegex(ValueError, expected):
    f(x, x)
@parameterized.named_parameters(
    {"testcase_name": f"_mesh={device_mesh_shape}".replace(" ", ""),
     "device_mesh_shape": device_mesh_shape}
    for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testNestedShardingAndStacking(self, device_mesh_shape):
  """A nested identity pmap returns its input with the same shape."""
  mesh_shape = self._getMeshShape(device_mesh_shape)
  f = self.pmap(self.pmap(lambda x: x, 'i'), 'j')
  shape = mesh_shape + (4,)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  result = f(x)
  self.assertEqual(result.shape, x.shape)
  self.assertAllClose(result, x, check_dtypes=False)
def testPartiallyMapped(self):
  """in_axes=None arguments broadcast; outputs are not marked Replicated."""
  f = self.pmap(lambda x, y: x, in_axes=(None, 0))
  g = self.pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i',
                in_axes=(None, 0))
  mesh_shape = (jax.device_count(),)
  shape = mesh_shape + (4,)
  x = np.array(3., dtype=np.float32)
  y = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

  def check_not_replicated(out):
    # The values may be identical on every device, but out_axes is
    # implicitly 0, so no mesh axis may be marked Replicated.
    self.assertIsInstance(out, array.ArrayImpl)
    mesh_mapping = out.sharding.sharding_spec.mesh_mapping
    self.assertEmpty([a for a in mesh_mapping
                      if isinstance(a, pxla.Replicated)])

  f_ans = f(x, y)
  self.assertAllClose(f_ans, np.broadcast_to(x, mesh_shape))
  check_not_replicated(f_ans)

  g_ans = g(x, y)
  g_expected = np.broadcast_to(x - np.sum(y, 0, keepdims=True), shape)
  self.assertAllClose(g_ans, g_expected)
  check_not_replicated(g_ans)
@parameterized.named_parameters(
    {"testcase_name": f"_mesh={device_mesh_shape}".replace(" ", ""),
     "device_mesh_shape": device_mesh_shape}
    for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testPartiallyMappedNested(self, device_mesh_shape):
  """Nested pmaps with an unmapped scalar psum over the inner axis only."""
  mesh_shape = self._getMeshShape(device_mesh_shape)
  inner = self.pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i',
                    in_axes=(None, 0))
  f = self.pmap(inner, axis_name='j', in_axes=(None, 0))
  x = 3.
  y = np.arange(math.prod(mesh_shape), dtype=np.float32).reshape(mesh_shape)
  # 'i' is the inner pmap, so it reduces over axis 1 of y.
  expected = np.broadcast_to(x - np.sum(y, 1, keepdims=True), mesh_shape)
  self.assertAllClose(f(x, y), expected, check_dtypes=False)
def testJvpAndPartialEval(self):
  """linearize through pmap yields cos(x) and can be staged to a jaxpr."""
  @partial(self.pmap, axis_name='i')
  def f(x):
    return jnp.sin(x)

  def splitjvp(x):
    _, f_jvp = linearize(f, x)
    return f_jvp(jnp.ones_like(x))

  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  self.assertAllClose(splitjvp(x), np.cos(x), check_dtypes=False)
  make_jaxpr(splitjvp)(x)  # Must not crash.
def testGradBasic(self):
  """grad through a pmapped sin matches grad of the unmapped sin."""
  @partial(self.pmap, axis_name='i')
  def f(x):
    return jnp.sin(x)

  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  reference_grad = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
  pmapped_grad = grad(lambda x: jnp.sum(f(x)))(x)
  self.assertAllClose(reference_grad, pmapped_grad, check_dtypes=False)
def testGradOfPsum(self):
  """Numerical gradient check (order 2, fwd and rev) of psum under pmap."""
  @partial(self.pmap, axis_name='i')
  def f(x):
    return lax.psum(x, axis_name='i')

  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  tol = 1e-2
  jtu.check_grads(f, (x,), 2, ["fwd", "rev"], tol, tol, eps=1.)
def testGradOfJvp(self):
  """grad of a linearized pmapped sin matches grad of a plain jvp of sin."""
  @partial(self.pmap, axis_name='i')
  def f(x):
    return jnp.sin(x)

  def splitjvp(x):
    _, f_jvp = linearize(f, x)
    return f_jvp(jnp.ones_like(x))

  def reference(x):
    _, tangent_out = jvp(jnp.sin, (x,), (jnp.ones_like(x),))
    return jnp.sum(tangent_out)

  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
  self.assertAllClose(ans, grad(reference)(x))
def testTwoArgsGrad(self):
  """grad through a two-argument pmap with psum matches a plain reference.

  The pmapped ``f`` computes ``5 cos(x) sin(y)`` per device and psums it over
  the mapped axis, so every device holds the column sums. ``g`` computes the
  same values without pmap: it sums over axis 0 and broadcasts back to the
  mapped shape.
  """
  def f(x, y):
    return lax.psum(5. * jnp.cos(x) * jnp.sin(y), 'i')
  f = self.pmap(f, 'i')

  def g(x, y):
    # psum over 'i' only reduces the mapped axis (axis 0). Summing over every
    # axis here would make g disagree with f by a factor of x.shape[1].
    tot = jnp.sum(5. * jnp.cos(x) * jnp.sin(y), axis=0)
    return tot * jnp.ones_like(x)  # broadcast to map like pjit does

  shape = (jax.device_count(), 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  y = 4 + x
  # Differentiate the pmapped f; differentiating g on both sides would leave
  # the pmap path untested.
  ans = grad(lambda x, y: jnp.sum(f(x, y)))(x, y)
  expected = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
  self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
    {"testcase_name": f"_mesh={device_mesh_shape}".replace(" ", ""),
     "device_mesh_shape": device_mesh_shape}
    for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testNestedWithClosure(self, device_mesh_shape):
  """grad through nested pmaps with closed-over values matches nested vmaps."""
  mesh_shape = self._getMeshShape(device_mesh_shape)

  def build(outer_map, inner_map):
    # The same computation is built twice, once with pmaps and once with
    # vmaps; the inner function closes over x and y from the outer scope.
    @outer_map
    def fun(x):
      y = jnp.sum(jnp.sin(x))

      @inner_map
      def g(z):
        return 3. * jnp.exp(jnp.sin(x).sum() * jnp.cos(y) * jnp.tan(z))

      return grad(lambda w: jnp.sum(g(w)))(x)
    return fun

  test_fun = build(partial(self.pmap, axis_name='i'),
                   partial(self.pmap, axis_name='j'))
  baseline_fun = build(vmap, vmap)

  shape = mesh_shape + (4,)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
  expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
  self.assertAllClose(ans, expected, atol=1e-3, rtol=1e-3)
def testArrays(self):
  """pmap consumes and produces jax Arrays in a variety of situations."""
  double = self.pmap(lambda v: 2 * v, axis_name='i')
  n = jax.device_count()
  base = np.arange(n * 4, dtype=np.float32).reshape(n, 4)

  # Arrays come out of pmap and can be fed straight back in.
  once = double(base)
  self.assertIsInstance(once, jax.Array)
  self.assertIsInstance(once, array.ArrayImpl)
  self.assertNotIsInstance(once, np.ndarray)
  self.assertAllClose(once, 2 * base, check_dtypes=False)
  twice = double(once)
  self.assertIsInstance(twice, array.ArrayImpl)
  self.assertNotIsInstance(twice, np.ndarray)
  self.assertAllClose(twice, 2 * 2 * base, check_dtypes=False)

  # A plain device_put Array is also accepted as input.
  once = double(device_put(base))
  self.assertIsInstance(once, array.ArrayImpl)
  self.assertAllClose(once, 2 * base, check_dtypes=False)

  # A pmap output can be consumed by an ordinary (non-pmap) computation.
  self.assertAllClose(once + once, 2 * 2 * base, check_dtypes=False)

  # Dispatch must cope with device movement: rebuild `once` with its
  # per-device buffers (and their devices) in reverse order.
  shards = once._arrays[::-1]
  reversed_sharding = jax.sharding.PmapSharding(
      [next(iter(s.devices())) for s in shards], once.sharding.sharding_spec)
  reordered = jax.make_array_from_single_device_arrays(
      once.shape, reversed_sharding, shards)
  result = double(reordered)
  self.assertAllClose(result, 2 * 2 * base[::-1], check_dtypes=False)

  # The repr of the result must not raise.
  repr(result)

  # jit can lexically capture a pmap-produced Array as a constant.
  add_captured = jit(lambda v: v + reordered)
  self.assertAllClose(add_captured(7), reordered + 7)
# Tests edge cases in lax._reshape_sharded_device_array: reshaping the Array
# returned by an identity pmap must agree with reshaping the host array.
@parameterized.named_parameters(
    {"testcase_name": f"_in={src}_out={dst}".replace(" ", ""),
     "in_shape": src, "out_shape": dst}
    for src, dst in [
        [(1, 1), (1,)], [(1,), (1, 1)], [(1,), ()], [(4, 7), (2, 2, 7)]
    ])
def testArrayReshape(self, in_shape, out_shape):
  # The leading dim of either shape (if any) must fit on the available devices.
  needed = max(in_shape[:1] + out_shape[:1])
  if jax.device_count() < needed:
    raise SkipTest("not enough devices")
  data = np.arange(math.prod(in_shape)).reshape(in_shape)
  sharded = self.pmap(lambda v: v)(data)
  self.assertAllClose(sharded.reshape(out_shape), data.reshape(out_shape),
                      check_dtypes=False)
def testPsumMultiple(self):
  """psum over both axes of a nested pmap sums across the whole mesh."""
  inner = self.pmap(lambda v: lax.psum(v, ('i', 'j')), 'i')
  nested = self.pmap(inner, 'j')

  def reduce_and_tile(a, ax):
    # Sum along `ax` and repeat the result so the shape is unchanged.
    return np.repeat(np.sum(a, ax, keepdims=True), a.shape[ax], ax)

  n = jax.device_count()
  pairs, leftover = divmod(n, 2)
  # Use an (n/2, 2) mesh when the device count allows it; otherwise (n, 1).
  shape = (pairs, 2, 4) if (pairs > 1 and not leftover) else (n, 1, 4)
  data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  result = nested(data)
  reference = reduce_and_tile(reduce_and_tile(data, 0), 1)
  self.assertAllClose(result, reference, check_dtypes=False)
def testPsumConstantReplicaGroups(self):
  """psum of a constant within replica groups equals constant * group size."""
  replicas = jax.device_count()
  if replicas % 2 != 0:
    raise SkipTest("Test expected an even number of devices.")
  # Two equal-size groups: [0, ..., n/2 - 1] and [n/2, ..., n - 1].
  group_size = replicas // 2
  axis_index_groups = np.arange(replicas).reshape(2, group_size).tolist()
  f = lambda x: x - lax.psum(2., 'i', axis_index_groups=axis_index_groups)
  f = self.pmap(f, 'i')
  shape = (replicas, 4)
  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  # Every member of a group contributes 2., so each group's psum is
  # 2. * group_size. Parenthesized explicitly: `2. * replicas // 2` would
  # parse as `(2. * replicas) // 2`.
  expected_psum = 2. * group_size
  expected = x - expected_psum
  ans = f(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def testPsumUnevenReplicaGroups(self):
  """psum with replica groups of different sizes ([0, 1] and the rest)."""
  n = jax.device_count()
  if n <= 2:
    raise SkipTest("Test expected devices greater than 2.")
  groups = [[0, 1], np.arange(2, n)]
  subtract_group_sum = lambda v: v - lax.psum(v, 'i', axis_index_groups=groups)
  mapped = self.pmap(subtract_group_sum, 'i')
  data = np.arange(n * 4, dtype=np.float32).reshape(n, 4)

  def tiled_sum(block):
    # Per-column sum of `block`, repeated once for every row of `block`.
    return np.broadcast_to(block.sum(0, keepdims=True),
                           (len(block), data.shape[1]))

  group_sums = np.concatenate([tiled_sum(data[0:2]), tiled_sum(data[2:])], 0)
  reference = data - group_sums
  self.assertAllClose(mapped(data), reference, check_dtypes=False)
def testPsumReplicaGroups(self):
  """psum restricted to two equal halves of the devices sums within each half."""
  n = jax.device_count()
  if n % 2 != 0:
    raise SkipTest
  half = n // 2
  groups = np.arange(n).reshape(2, half).tolist()
  mapped = self.pmap(
      lambda v: v - lax.psum(v, 'i', axis_index_groups=groups), 'i')
  data = np.arange(n * 4, dtype=np.float32).reshape(n, 4)
  # Column sums of each half, repeated for every row of that half.
  group_sums = np.concatenate(
      [np.broadcast_to(part.sum(0, keepdims=True), (half, data.shape[1]))
       for part in (data[:half], data[half:])], 0)
  self.assertAllClose(mapped(data), data - group_sums, check_dtypes=False)
def testGatherReplicaGroups(self):
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("Test expected an even number of devices greater than 1.")
axis_index_groups = np.arange(replicas, dtype=np.int32)
axis_index_groups = axis_index_groups.reshape((replicas // 2, 2)).T
axis_index_groups = axis_index_groups.tolist()
f = lambda x: lax.all_gather(x, 'i', axis_index_groups=axis_index_groups)
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
group_1_result = x[0::2]
group_2_result = x[1::2]