forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharray_test.py
1213 lines (1038 loc) · 47.3 KB
/
array_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Array."""
import contextlib
import math
import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import dispatch
from jax._src import op_shardings
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
pmap_sharding_devices_indices_map)
from jax.experimental.pjit import pjit
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
from jax._src import array
from jax._src import prng
from jax import config
config.parse_flags_with_absl()

# Original value of XLA_FLAGS; recorded by setUpModule and used by
# tearDownModule to restore the environment.
prev_xla_flags = None

try:
  import pytest
except ImportError:
  # pytest is optional; without it the module simply carries no marker.
  pass
else:
  # Tag this module with pytest's `multiaccelerator` marker.
  pytestmark = pytest.mark.multiaccelerator
# Run all tests with 8 CPU devices.
def setUpModule():
  """Request 8 host-platform devices from XLA for this test module.

  The incoming XLA_FLAGS value is stashed in the module-level
  `prev_xla_flags` so that it can be restored afterwards. A device count the
  user already supplied (and any other flags) is left untouched.
  """
  global prev_xla_flags
  prev_xla_flags = os.getenv("XLA_FLAGS")
  existing = prev_xla_flags or ""
  already_configured = "xla_force_host_platform_device_count" in existing
  if not already_configured:
    os.environ["XLA_FLAGS"] = (
        existing + " --xla_force_host_platform_device_count=8")
  # Backends are cached; drop them so the next lookup sees the new flags.
  xb.get_backend.cache_clear()
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  """Restore XLA_FLAGS to the value recorded in `prev_xla_flags`.

  If XLA_FLAGS was originally unset, it is removed again. `pop` with a
  default is used instead of `del` so that a variable which is already gone
  (e.g. removed by something else in the meantime) does not raise KeyError
  and mask the real test outcome.
  """
  if prev_xla_flags is None:
    os.environ.pop("XLA_FLAGS", None)
  else:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  # Drop cached backends so later modules are not stuck with our flags.
  xb.get_backend.cache_clear()
def create_array(shape, sharding, global_data=None):
  """Build a sharded Array from host data.

  Args:
    shape: global shape of the array.
    sharding: sharding used to place the per-device pieces.
    global_data: optional host array to slice from; when omitted,
      `np.arange(prod(shape))` reshaped to `shape` is used.

  Returns:
    A `(array, global_data)` tuple, where `global_data` is the host array the
    shards were cut from.
  """
  data = (np.arange(math.prod(shape)).reshape(shape)
          if global_data is None else global_data)

  def _fetch_shard(index):
    # Each device receives the slice of the host data named by its index.
    return data[index]

  return array.make_array_from_callback(shape, sharding, _fetch_shard), data
class JaxArrayTest(jtu.JaxTestCase):
  """Tests for `array.ArrayImpl`, the concrete `jax.Array` implementation.

  Most tests build a (4, 2) device mesh and therefore need the 8 devices
  requested by `setUpModule`.
  """

  def test_array_impl_name(self):
    self.assertEqual(array.ArrayImpl.__name__, "ArrayImpl")

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_jax_array_value(self, mesh_axes):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, global_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes))
    for s in arr.addressable_shards:
      self.assertTrue(dispatch.is_single_device_sharding(s.data.sharding))
      self.assertArraysEqual(s.data, global_data[s.index])
    self.assertArraysEqual(arr._value, global_data)
    self.assertArraysEqual(arr._npy_value, global_data)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more shards; for convenience only the first 2 are checked.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", P("x"),
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", P("y"),
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", P(None, "y"),
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", P(("x", "y")),
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", P(),
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                          expected_replica_ids, expected_is_fully_replicated):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    sharding = jax.sharding.NamedSharding(global_mesh, mesh_axes)
    arr, global_input_data = create_array(global_input_shape, sharding)
    self.assertEqual(arr.ndim, 2)
    self.assertEqual(arr.size, 16)
    self.assertEqual(arr.addressable_shards[0].index, expected_index[0])
    self.assertEqual(arr.addressable_shards[1].index, expected_index[1])
    replica_ids = [i.replica_id for i in arr.addressable_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in arr.addressable_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(arr.is_fully_replicated, expected_is_fully_replicated)
    # `shard` (not `s`) so the sharding variable above is not shadowed.
    for i, shard in enumerate(arr.addressable_shards):
      self.assertEqual(shard.data.aval,
                       core.ShapedArray(expected_shard_shape, shard.data.dtype))
      self.assertArraysEqual(shard.data, global_input_data[shard.index])
      self.assertArraysEqual(shard.data, arr.addressable_data(i))

    for g, l in safe_zip(arr.global_shards, arr.addressable_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)

  def test_addressable_data(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    s = jax.sharding.NamedSharding(global_mesh, P(None))
    arr, inp_data = create_array(shape, s)
    # Iterate over the shards themselves: `len(arr)` is the size of axis 0
    # and only matched the shard count here by coincidence.
    for i in range(len(arr.addressable_shards)):
      self.assertArraysEqual(inp_data, arr.addressable_data(i))

  def test_array_delete(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    arr.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      arr._check_if_deleted()
    self.assertIsNone(arr._npy_value)
    self.assertIsNone(arr._arrays)

  def test_single_device_array_usage_after_delete(self):
    x = jnp.array([1, 2, 3])
    x.delete()

    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      _ = x + 1

  def test_multi_device_array_usage_after_delete(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    arr = jax.device_put(np.arange(math.prod(shape), dtype=np.int32),
                         jax.sharding.NamedSharding(global_mesh, P('x')))
    arr.delete()

    with self.assertRaisesRegex(
        RuntimeError, r'Array has been deleted with shape=int32\[16\].'):
      _ = arr + 1

  def test_device_put(self):
    numpy_array = np.array([1, 2, 3])
    arr = jax.device_put(numpy_array, jax.devices()[0])
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
    self.assertArraysEqual(arr, numpy_array)
    self.assertEqual(arr._committed, True)
    for i in arr.addressable_shards:
      self.assertArraysEqual(i.data, numpy_array)
      self.assertEqual(i.device, jax.devices()[0])
      self.assertEqual(i.index, (slice(None),))
      self.assertEqual(i.replica_id, 0)

  def test_device_put_array_delete(self):
    arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
    arr.delete()
    with self.assertRaisesRegex(RuntimeError, 'Array has been deleted.'):
      arr._check_if_deleted()
    self.assertIsNone(arr._npy_value)
    self.assertIsNone(arr._arrays)

  def test_array_device_get(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    self.assertArraysEqual(jax.device_get(arr), input_data)

  def test_repr(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    self.assertStartsWith(repr(arr), "Array(")

  def test_jnp_array(self):
    arr = jnp.array([1, 2, 3])
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertTrue(dispatch.is_single_device_sharding(arr.sharding))
    self.assertEqual(arr._committed, False)
    self.assertFalse(arr.weak_type)

  def test_jnp_array_jit_add(self):
    a = jnp.array([1, 2, 3])
    b = jnp.array([4, 5, 6])
    arr = jax.jit(lambda x, y: x + y)(a, b)
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)

  def test_jnp_array_jnp_add(self):
    arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)

  def test_jnp_array_normal_add(self):
    a = jnp.array([1, 2, 3])
    b = jnp.array([4, 5, 6])
    arr = a + b
    self.assertIsInstance(arr, array.ArrayImpl)
    self.assertArraysEqual(arr, np.array([5, 7, 9]))
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)

  def test_array_sharded_astype(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    arr_float32 = arr.astype(jnp.float32)
    self.assertEqual(arr_float32.dtype, np.float32)
    self.assertArraysEqual(arr_float32, input_data.astype(np.float32))
    self.assertLen(arr_float32.addressable_shards, 8)
    for i in arr_float32.addressable_shards:
      self.assertArraysEqual(i.data, input_data[i.index].astype(np.float32))

  def test_jnp_array_astype(self):
    arr = jnp.array([1, 2, 3])
    arr_float32 = arr.astype(jnp.float32)
    self.assertEqual(arr_float32.dtype, np.float32)
    self.assertArraysEqual(arr_float32, arr.astype(np.float32))

  def test_array_delete_idempotent(self):
    mesh = jtu.create_global_mesh((2,), ('x',))
    arr = jax.device_put(np.arange(8), jax.sharding.NamedSharding(mesh, P('x')))

    arr.delete()
    self.assertTrue(arr.is_deleted())

    arr.delete()  # Run delete again to check if it's idempotent.
    self.assertTrue(arr.is_deleted())

  def test_sharded_add(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    a, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    b, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x')))
    out = a + b
    expected = input_data + input_data
    self.assertArraysEqual(out, expected)
    self.assertLen(out.addressable_shards, 8)
    for i in out.addressable_shards:
      self.assertArraysEqual(i.data, expected[i.index])

  def test_sharded_zeros_like(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    a, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    out = jnp.zeros_like(a)
    expected = jnp.zeros(input_data.shape, dtype=a.dtype)
    self.assertArraysEqual(out, expected)
    self.assertLen(out.addressable_shards, 8)
    for i in out.addressable_shards:
      self.assertArraysEqual(i.data, expected[i.index])

  def test_zeros_like(self):
    a = jnp.array([1, 2, 3], dtype=np.int32)
    out = jnp.zeros_like(a)
    expected = np.zeros(a.shape, dtype=np.int32)
    self.assertArraysEqual(out, expected)
    self.assertTrue(dispatch.is_single_device_sharding(out.sharding))

  def test_wrong_num_arrays(self):
    shape = (8, 2)
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    devices = jax.local_devices()[:8]  # Taking up to 8 devices
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    di_map = s.devices_indices_map(shape)
    bufs = [jax.device_put(inp_data[di_map[d]], d) for d in devices]
    with self.assertRaisesRegex(
        ValueError,
        r'Expected 8 per-device arrays \(this is how many devices are addressable '
        r'by the sharding\), but got 4'):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs[:4],
                      committed=True)

    with self.assertRaisesRegex(
        ValueError,
        r'Expected 8 per-device arrays \(this is how many devices are addressable '
        r'by the sharding\), but got 16'):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs + bufs,
                      committed=True)

  def test_arrays_not_in_device_assignment(self):
    if jax.device_count() < 4:
      self.skipTest('Requires at least 4 devices')
    shape = (8, 2)
    mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    # sharding device ids = {0, 1}
    s = jax.sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {2, 3}
    bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
    # Braces escaped for consistency with the sibling tests below.
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        r"Sharding contains devices \{0, 1\} that are not present in per-device "
        r"arrays. Per-device arrays contain devices \{2, 3\} that are not present "
        "in the sharding."):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs,
                      committed=True)

  def test_more_devices_in_sharding_than_arrays(self):
    shape = (8, 2)
    mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    # Sharding device ids = {0, 1}
    s = jax.sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {0, 0}
    bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        r"Sharding contains devices \{1\} that are not present in per-device "
        "arrays."):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs,
                      committed=True)

  def test_different_devices_in_arrays_than_sharding(self):
    if jax.device_count() < 3:
      self.skipTest('Requires at least 3 devices')
    shape = (8, 2)
    # `('x',)`: a bare `('x')` is just the string 'x', not a tuple.
    mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]),
                             ('x',))
    # sharding device ids = {1, 2}
    s = jax.sharding.NamedSharding(mesh, P('x'))
    inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    # _arrays device ids = {0, 1}
    bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]]
    with self.assertRaisesRegex(
        ValueError,
        "Addressable devices and per-device arrays devices do not match. "
        r"Sharding contains devices \{2\} that are not present in per-device "
        r"arrays. Per-device arrays contain devices \{0\} that are not present "
        "in the sharding."):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs,
                      committed=True)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"), (2, 2)),
      ("mesh_x", P("x"), (2, 4)),
      ("mesh_y", P("y"), (4, 4)),
      ("mesh_none_y", P(None, "y"), (8, 2)),
      ("mesh_none_x", P(None, "x"), (8, 1)),
      ("mesh_xy", P(("x", "y")), (1, 4)),
      ("mesh_replicated", P(()), (8, 4)),
  )
  def test_shard_shape_mismatch_with_buffer_shape(self, pspec,
                                                  expected_shard_shape):
    shape = (8, 4)
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    mps = jax.sharding.NamedSharding(mesh, pspec)
    # Every callback returns this 5-element array, whose shape matches none of
    # the expected shard shapes.
    inp_data = np.arange(5)

    # Escape the tuple's parentheses so they match literally in the regex.
    str_expected_shard_shape = str(expected_shard_shape).replace(
        r"(", r"\(").replace(r")", r"\)")
    with self.assertRaisesRegex(
        ValueError,
        f"Expected shard shape {str_expected_shard_shape} doesn't match the "
        "single device array shape"):
      array.make_array_from_callback(shape, mps, lambda idx: inp_data)

  def test_mismatch_dtype(self):
    shape = (8, 2)
    mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
    inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
    indices = s.devices_indices_map(shape)
    bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
    with self.assertRaisesRegex(
        ValueError,
        "Input buffers to `Array` must have matching dtypes. "
        "Got int32, expected float32"):
      array.ArrayImpl(core.ShapedArray(shape, np.float32), s, bufs,
                      committed=True)

  def test_array_iter_pmap_sharding(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
    y = jax.pmap(jnp.sin)(x)
    self.assertArraysEqual([list(a.devices())[0] for a in y],
                           y.sharding._device_assignment,
                           allow_object_dtype=True)

    sin_x = iter(np.sin(x))
    for i, j in zip(iter(y), sin_x):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysAllClose(i, j)

  def test_array_iter_pmap_sharding_last_dim_sharded(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
    y = jax.pmap(jnp.sin, out_axes=1)(x)

    for i, j in zip(iter(y), iter(np.sin(x).T)):
      self.assertArraysAllClose(i, j)

  def test_array_iter_mesh_pspec_sharding_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    for i, j in zip(iter(arr), iter(input_data)):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysEqual(i, j)

  def test_array_iter_replicated_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))

    for i, j in zip(iter(arr), iter(input_data)):
      self.assertIsInstance(i, array.ArrayImpl)
      self.assertArraysEqual(i, j)
      self.assertLen(i.sharding.device_set, 8)
      self.assertTrue(
          op_shardings.are_op_shardings_equal(
              arr.sharding._to_xla_hlo_sharding(arr.ndim),
              i.sharding._to_xla_hlo_sharding(i.ndim)))

  def test_array_getitem_mesh_pspec_sharding_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    s = arr[2:4, 0:1]
    self.assertIsInstance(s, array.ArrayImpl)
    self.assertArraysEqual(s, input_data[2:4, 0:1])

    p = arr[:2]
    self.assertIsInstance(p, array.ArrayImpl)
    self.assertArraysEqual(p, input_data[:2])

  def test_array_getitem_compile_multi_device_sharding(self):
    def _check(out, inp, shard_shape):
      # Value must match NumPy, and the result must stay multi-device with
      # the expected per-shard shape.
      self.assertArraysEqual(out, inp)
      self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape)
      self.assertNotIsInstance(out.sharding, jax.sharding.SingleDeviceSharding)

    global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
    input_shape = (4, 4, 2)
    arr, np_inp = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y', 'z')))

    _check(arr[:, -1, :], np_inp[:, -1, :], (2, 1))
    _check(arr[0, 0, 0], np_inp[0, 0, 0], ())
    _check(arr[-1, -1, :], np_inp[-1, -1, :], (1,))
    _check(arr[:, 1, 0], np_inp[:, 1, 0], (2,))
    _check(arr[:, :, :], np_inp[:, :, :], (2, 2, 1))
    _check(arr[3, :, :], np_inp[3, :, :], (2, 1))
    _check(arr[-1, -1, -1], np_inp[-1, -1, -1], ())
    _check(arr[2, -1, :], np_inp[2, -1, :], (1,))
    _check(arr[2, 3, 1], np_inp[2, 3, 1], ())
    _check(arr[-1], np_inp[-1], (2, 1))
    _check(arr[:], np_inp[:], (2, 2, 1))
    _check(arr[np.array(0), :, :], np_inp[np.array(0), :, :], (2, 1))
    _check(arr[jnp.array(0), :, :], np_inp[jnp.array(0), :, :], (2, 1))
    _check(arr[0, :2, 1], np_inp[0, :2, 1], (2,))
    _check(arr[:, 1::2], np_inp[:, 1::2], (2, 2, 1))
    _check(arr[:, -1:, :], np_inp[:, -1:, :], (2, 1, 1))
    _check(arr[0:6:1], np_inp[0:6:1], (2, 2, 1))
    _check(arr[:4], np_inp[:4], (2, 2, 1))
    _check(arr[::-1], np_inp[::-1], (2, 2, 1))
    _check(arr[1], np_inp[1], (2, 1))

  def test_array_getitem_replicated_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P(None)))

    s = arr[2:4, 0:1]
    self.assertIsInstance(s, array.ArrayImpl)
    self.assertArraysEqual(s, np.array([[4], [6]]))
    self.assertLen(s.sharding.device_set, 8)
    self.assertTrue(
        op_shardings.are_op_shardings_equal(
            arr.sharding._to_xla_hlo_sharding(arr.ndim),
            s.sharding._to_xla_hlo_sharding(s.ndim)))

    p = arr[:2]
    self.assertIsInstance(p, array.ArrayImpl)
    self.assertArraysEqual(p, input_data[:2])
    # Check `p` here; previously `s` was re-checked by mistake, so `p`'s
    # sharding was never verified.
    self.assertLen(p.sharding.device_set, 8)
    self.assertTrue(
        op_shardings.are_op_shardings_equal(
            arr.sharding._to_xla_hlo_sharding(arr.ndim),
            p.sharding._to_xla_hlo_sharding(p.ndim)))

  def test_array_iter_mesh_pspec_sharding_single_device(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    single_dev = jax.devices()[1:2]
    # `('x',)`: a bare `('x')` is just the string 'x', not a tuple.
    mesh = jax.sharding.Mesh(np.array(single_dev), ('x',))
    input_shape = (8, 2)
    arr, input_data = create_array(
        input_shape, jax.sharding.NamedSharding(mesh, P('x')))

    for i, j in zip(arr, iter(input_data)):
      self.assertArraysEqual(i, j)
      self.assertEqual(i.devices(), {single_dev[0]})

  def test_array_shards_committed(self):
    if jax.device_count() < 2:
      self.skipTest('Test requires >= 2 devices.')

    x = jnp.array([1, 2, 3])
    for s in x.addressable_shards:
      self.assertEqual(s.data._committed, x._committed)
      self.assertFalse(s.data._committed)

    y = jax.device_put(x, jax.devices()[1])
    for s in y.addressable_shards:
      self.assertEqual(s.data._committed, y._committed)
      self.assertTrue(s.data._committed)

  def test_array_jnp_array_copy_multi_device(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    c_arr = jnp.array(arr, copy=True)
    self.assertArraysEqual(arr, c_arr)
    self.assertEqual(arr._committed, c_arr._committed)
    for a, c in safe_zip(arr.addressable_shards, c_arr.addressable_shards):
      self.assertArraysEqual(a.data, c.data)
      self.assertEqual(a.index, c.index)
      self.assertEqual(a.replica_id, c.replica_id)
      self.assertEqual(a.device, c.device)
      # A real copy must not alias the original device buffer.
      self.assertNotEqual(a.data.unsafe_buffer_pointer(),
                          c.data.unsafe_buffer_pointer())

  def test_array_addressable_shards(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    arr, _ = create_array(
        input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

    for a in arr.addressable_shards:
      self.assertIsInstance(a.data, array.ArrayImpl)

    x = jnp.array([1, 2, 3])
    self.assertIsInstance(x.addressable_data(0), array.ArrayImpl)

  def test_shape_dtype_struct_sharding_jit(self):
    # `('x',)`: a bare `('x')` is just the string 'x', not a tuple.
    mesh = jtu.create_global_mesh((8,), ('x',))
    s = jax.sharding.NamedSharding(mesh, P('x'))

    x_dummy = jax.ShapeDtypeStruct(
        shape=(16,),
        dtype=jnp.dtype('float32'),
        sharding=s)

    def f(x):
      return x * 2

    c = jax.jit(f).lower(x_dummy).compile()
    input_shardings, output_shardings = c.input_shardings, c.output_shardings
    self.assertLen(input_shardings, 2)
    # The second entry (presumably keyword-argument shardings) is empty here.
    self.assertEqual(input_shardings[1], {})

    self.assertTrue(
        op_shardings.are_op_shardings_equal(
            input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim),
            s._to_xla_hlo_sharding(x_dummy.ndim)))
    self.assertTrue(
        op_shardings.are_op_shardings_equal(
            output_shardings._to_xla_hlo_sharding(x_dummy.ndim),
            s._to_xla_hlo_sharding(x_dummy.ndim)))

  def test_shape_dtype_struct_sharding_pjit(self):
    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    s = jax.sharding.NamedSharding(mesh, P('x', 'y'))

    def f(x):
      return x * 2.

    x_dummy = jax.ShapeDtypeStruct(
        shape=(8, 2),
        dtype=jnp.dtype('float32'),
        sharding=s)

    c = pjit(f).lower(x_dummy).compile()
    input_shardings, output_shardings = c.input_shardings, c.output_shardings
    self.assertTrue(
        op_shardings.are_op_shardings_equal(
            input_shardings[0][0]._to_xla_hlo_sharding(x_dummy.ndim),
            s._to_xla_hlo_sharding(x_dummy.ndim)))
    self.assertTrue(
        op_shardings.are_op_shardings_equal(
            output_shardings._to_xla_hlo_sharding(x_dummy.ndim),
            s._to_xla_hlo_sharding(x_dummy.ndim)))

  # TODO(skyewm): remove this test when we can remove the workaround manual
  # defragment API
  @jtu.skip_on_devices('cpu')  # defragment not implemented for TFRT CPU
  def test_defragment(self):
    if xb.using_pjrt_c_api():
      self.skipTest("Manual defragment not exposed via PJRT C API")

    # Create a few arrays
    global_mesh = jtu.create_global_mesh((jax.local_device_count(),), ('x',))
    shape = (8, 2)
    mpsharding = jax.sharding.NamedSharding(global_mesh, P('x',))
    arr1, data = create_array(shape, mpsharding)
    arr2, _ = create_array(shape, mpsharding, data)
    arr3, _ = create_array(shape, mpsharding, data)

    # Delete one of them
    arr2.delete()

    # Defragment
    xb.get_backend().defragment()

    # Sanity check remaining arrays
    self.assertArraysEqual(arr1, data)
    self.assertArraysEqual(arr1 + arr3, data * 2)

    # TODO(skyewm): check that defragmentation actually happened. I originally
    # thought to do this with unsafe_buffer_pointer(), but that's not always the
    # device memory address. Other ideas include causing enough fragmentation to
    # OOM, and exposing allocator stats in Python.

  def test_on_device_size_in_bytes(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    a, _ = create_array(
        (8, 2), jax.sharding.NamedSharding(global_mesh, P('x', 'y')))
    shard_size = a.addressable_shards[0].data.on_device_size_in_bytes()
    # Each shard holds 2 elements; at least 4 bytes each is assumed here.
    self.assertGreaterEqual(shard_size, 4 * 2)
    self.assertEqual(shard_size * len(a.global_shards),
                     a.on_device_size_in_bytes())

  def test_array_is_ready(self):
    x = jax.device_put(jnp.arange(8.), jax.devices()[0])
    x.is_ready()  # doesn't crash

  def test_process_allgather_single_host(self):
    x = jnp.arange(8.)
    out = multihost_utils.process_allgather(x)
    self.assertEqual(out.shape, x.shape)
    self.assertArraysEqual(out, x)

  @jtu.sample_product(
      dtype=jtu.dtypes.all,
      # `(10,)`, not `(10)`: the latter is just the int 10, not a 1-tuple.
      shape=[(), (10,), (2, 3)],
  )
  @jtu.run_on_devices("cpu")
  def test_buffer_protocol(self, dtype, shape):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = jax.device_put(x)
    if dtype == jax.dtypes.bfloat16:
      with self.assertRaisesRegex(
          BufferError,
          'Buffers of type BF16 are not supported by the Python buffer '
          'protocol.'
      ):
        memoryview(y)
      return

    x_bytes = memoryview(x).tobytes()
    y_bytes = memoryview(y).tobytes()
    self.assertEqual(x_bytes, y_bytes)

  @jtu.run_on_devices("cpu")
  def test_buffer_protocol_deletion(self):
    rng = jtu.rand_default(self.rng())
    x = rng((3, 4), np.float32)
    y = jax.device_put(x)
    x_bytes = memoryview(x).tobytes()
    y_view = memoryview(y)
    # The array does not actually get deleted until any external reference is
    # dropped. Arguably we should make calling delete() in these circumstances
    # return an error instead, but that would be a behavior change for existing
    # users.
    y.delete()
    y_bytes = y_view.tobytes()
    self.assertEqual(x_bytes, y_bytes)

  def test_array_copy_to_host_async(self):
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    x = pjit(lambda: jnp.arange(8.),
             out_shardings=jax.sharding.NamedSharding(global_mesh, P(None)))()
    self.assertLen(x.sharding.device_set, 4)
    x.copy_to_host_async()  # doesn't crash
    self.assertArraysEqual(np.arange(8.), x)

  def test_array_fully_replicated_shard(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    inp_shape = (8, 2)
    arr, inp_data = create_array(
        inp_shape, jax.sharding.NamedSharding(global_mesh, P()))
    fs = arr._fully_replicated_shard()
    self.assertEqual(fs.shape, inp_shape)
    self.assertTrue(dispatch.is_single_device_sharding(fs.sharding))
    self.assertArraysEqual(fs, inp_data)
    self.assertArraysEqual(arr.addressable_data(0), inp_data)

  def test_shard_array_to_fully_replicated(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    sharding = jax.sharding.NamedSharding(global_mesh, P())
    arr = jnp.arange(16)
    self.assertFalse(arr._committed)
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
    out = jax.jit(lambda x: x * 2, in_shardings=sharding)(arr)
    self.assertTrue(out.sharding.is_fully_replicated)
    self.assertArraysEqual(out, arr * 2)

  def test_fully_replicated_donated_array_is_deleted(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    sharding = jax.sharding.NamedSharding(global_mesh, P())
    arr = jnp.arange(16)
    # Keep a copy: `arr` is donated below and must not be read afterwards.
    arr_copy = arr.copy()
    self.assertFalse(arr._committed)
    self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding)
    out = jax.jit(lambda x: x * 2, in_shardings=sharding, donate_argnums=0)(arr)
    self.assertTrue(out.sharding.is_fully_replicated)
    self.assertArraysEqual(out, arr_copy * 2)
    self.assertTrue(arr.is_deleted())

  @parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
  @unittest.skipIf(xla_extension_version < 208, "Test requires jaxlib > 0.4.19")
  def test_shards_have_correct_dtype(self, dtype):
    x = jnp.ones((), dtype=dtype)
    for shard in x.addressable_shards:
      self.assertEqual(shard.data.dtype, dtype)
class ShardingTest(jtu.JaxTestCase):
def test_mesh_pspec_sharding_interface(self):
  """NamedSharding's index map, device assignment and HLO sharding agree."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 4)
  # Array dim 0 is split over mesh axis 'y' (size 2), dim 1 over 'x' (size 4).
  sharding = jax.sharding.NamedSharding(mesh, P('y', 'x'))

  indices = sharding.devices_indices_map(shape)
  hlo = sharding._to_xla_hlo_sharding(len(shape))
  assignment = sharding._device_assignment
  first_device = mesh.devices.flat[0]

  # 8 rows / 2 = 4 rows, 4 cols / 4 = 1 col for the first device.
  self.assertEqual(indices[first_device], (slice(0, 4), slice(0, 1)))
  self.assertArraysEqual(assignment, list(mesh.devices.flat),
                         allow_object_dtype=True)
  self.assertTrue(hlo.is_tiled())
  self.assertListEqual(hlo.tile_assignment_dimensions(), [2, 4])
  # Expected order corresponds to walking the mesh in ('y', 'x') order.
  self.assertListEqual(hlo.tile_assignment_devices(),
                       [0, 2, 4, 6, 1, 3, 5, 7])
@parameterized.named_parameters(
    ("mesh_x_y", P("x", "y")),
    ("mesh_x", P("x")),
    ("mesh_y", P("y")),
    ("mesh_none_y", P(None, "y")),
    ("mesh_none_x", P(None, "x")),
    ("mesh_xy", P(("x", "y"))),
    ("mesh_fully_replicated", P()),
)
def test_op_sharding_indices(self, pspec):
  """A GSPMDSharding rebuilt from a NamedSharding's HLO has the same index map."""
  shape = (8, 4)
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  named = jax.sharding.NamedSharding(mesh, pspec)
  device_list = list(mesh.devices.flat)
  gspmd = jax.sharding.GSPMDSharding(
      device_list, named._to_xla_hlo_sharding(len(shape)))
  expected = named.devices_indices_map(shape)
  self.assertDictEqual(gspmd.devices_indices_map(shape), expected)
@parameterized.named_parameters(
    ("mesh_x_y", P("x", "y"), (2, 2)),
    ("mesh_x", P("x"), (2, 4)),
    ("mesh_y", P("y"), (4, 4)),
    ("mesh_none_y", P(None, "y"), (8, 2)),
    ("mesh_none_x", P(None, "x"), (8, 1)),
    ("mesh_xy", P(("x", "y")), (1, 4)),
    ("mesh_fully_replicated", P(), (8, 4)),
)
def test_shard_shape(self, pspec, expected_shard_shape):
  """shard_shape of an (8, 4) array on a (4, 2) mesh matches the per-device block."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  sharding = jax.sharding.NamedSharding(mesh, pspec)
  actual = sharding.shard_shape((8, 4))
  self.assertEqual(actual, expected_shard_shape)
def test_uneven_shard_error(self):
  """shard_shape rejects a dimension not divisible by its mesh-axis size."""
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  # Dim 1 has size 3 but is split 2 ways over mesh axis 'y'.
  expected_msg = (
      r"Sharding.*implies that array axis 1 is partitioned 2 times, but the "
      r"dimension size is 3 \(full shape: \(8, 3\), per-dimension tiling "
      r"factors: \[4, 2\] should evenly divide the shape\)")
  with self.assertRaisesRegex(ValueError, expected_msg):
    sharding.shard_shape((8, 3))
def test_pmap_sharding_hash_eq(self):
  """Equal PmapShardings from separate pmap calls reuse cached index maps.

  The second lookup must produce cache hits and no new misses, which only
  happens if the two shardings hash and compare equal.
  """
  if jax.device_count() < 2:
    self.skipTest('Test needs >= 2 devices.')
  shape = (2, 2)
  size = math.prod(shape)

  first = jax.pmap(lambda x: x)(np.arange(size).reshape(shape))
  self.assertIsInstance(first.sharding, jax.sharding.PmapSharding)
  # Warm the devices_indices_map cache with the first sharding.
  first.sharding.devices_indices_map(shape)
  before = pmap_sharding_devices_indices_map.cache_info()

  # Different data, separately constructed pmap: the sharding should be equal.
  second = jax.pmap(lambda x: x)(np.arange(size, 2 * size).reshape(shape))
  second.sharding.devices_indices_map(shape)
  after = pmap_sharding_devices_indices_map.cache_info()

  self.assertGreater(after.hits, before.hits + 1)
  self.assertEqual(after.misses, before.misses)
def test_is_compatible_error(self):
  """A rank-4 PartitionSpec is rejected when checked against a rank-2 shape."""
  mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
  base = jax.sharding.NamedSharding(mesh, P(None, ('mdl',), None, None))
  # Rebuild via the parsed-pspec constructor and check that path's error text.
  rebuilt = jax.sharding.NamedSharding._from_parsed_pspec(
      base.mesh, base._parsed_pspec)
  expected_msg = (
      r"Sharding NamedSharding\(mesh=Mesh\('replica': 1, 'data': 1, 'mdl': 2\), "
      r"spec=PartitionSpec\(None, \('mdl',\), None, None\).*\) is only "
      "valid for values of rank at least 4, but was applied to a value of rank 2")
  with self.assertRaisesRegex(ValueError, expected_msg):
    rebuilt.is_compatible_aval((8, 2))
def test_is_subclass(self):
  """ArrayImpl subclasses jax.Array but not numpy.ndarray.

  Array counterpart of api_test.py::APITest::test_is_subclass.
  """
  impl = array.ArrayImpl
  self.assertTrue(issubclass(impl, jax.Array))
  self.assertFalse(issubclass(impl, np.ndarray))
def test_gspmd_sharding_repr(self):
  """GSPMDSharding's repr embeds the textual form of its OpSharding.

  Only a substring is matched because the memory kind is also included in the
  repr, but only on TPU.
  """
  devices = jax.devices()

  tiled = xc.OpSharding()
  tiled.type = xc.OpSharding.Type.OTHER
  tiled.tile_assignment_dimensions = [4, 1, 2]
  tiled.tile_assignment_devices = list(range(8))
  tiled.replicate_on_last_tile_dim = True
  tiled_repr = repr(jax.sharding.GSPMDSharding(devices, tiled))
  self.assertIn(
      'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
      'last_tile_dim_replicate}', tiled_repr)

  replicated = xc.OpSharding()
  replicated.type = xc.OpSharding.Type.REPLICATED
  replicated_repr = repr(jax.sharding.GSPMDSharding(devices, replicated))
  self.assertIn('GSPMDSharding({replicated}', replicated_repr)
@parameterized.named_parameters(
    ("mesh_x_y", P("x", "y"), (4, 2), (), False),
    ("mesh_x", P("x"), (4, 2), (1,), False),
    ("mesh_y", P("y"), (4, 2), (0,), True),
    ("mesh_none_y", P(None, "y"), (4, 2), (0,), False),
    ("mesh_none_x", P(None, "x"), (4, 2), (1,), True),
    ("mesh_xy", P(("x", "y")), (8, 1), (), False),
    ("mesh_fully_replicated", P(), (4, 2), None, False),
)
def test_positional_sharding_op_sharding_lowering(
    self, pspec, shape, axes, transpose):
  """A PositionalSharding lowers to the same HLO as an equivalent NamedSharding.

  Args:
    pspec: partition spec applied to the (4, 2) ('x', 'y') mesh.
    shape: shape the (up to 8) local devices are reshaped into.
    axes: argument passed to `PositionalSharding.replicate`.
    transpose: whether to transpose the positional sharding afterwards.
  """
  value_shape = (8, 4)
  ndim = len(value_shape)
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  named = jax.sharding.NamedSharding(mesh, pspec)

  positional = jax.sharding.PositionalSharding(jax.local_devices()[:8])
  positional = positional.reshape(shape).replicate(axes)
  if transpose:
    positional = positional.T

  named_hlo = named._to_xla_hlo_sharding(ndim)
  positional_hlo = positional._to_xla_hlo_sharding(ndim)
  self.assertEqual(named.shard_shape(value_shape),
                   positional.shard_shape(value_shape))
  self.assertTrue(
      op_shardings.are_op_shardings_equal(named_hlo, positional_hlo))
@parameterized.named_parameters(
    ("2d_mesh_x_y", (4, 2), P("x", "y")),
    ("2d_mesh_x", (4, 2), P("x")),
    ("2d_mesh_y", (4, 2), P("y")),
    ("2d_mesh_none_y", (4, 2), P(None, "y")),
    ("2d_mesh_none_x", (4, 2), P(None, "x")),
    ("2d_mesh_xy", (4, 2), P(("x", "y"))),
    ("2d_mesh_none_xy", (4, 2), P(None, ("x", "y"))),
    ("2d_mesh_x_none", (2, 1), P(('x',), None)),
    ("2d_mesh_fully_replicated", (4, 2), P()),
    ("3d_mesh_none_none_z", (2, 2, 2), P(None, None, 'z')),
    ("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)),
    ("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)),
    ("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))),
    ("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')),
    ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
    ("3d_mesh_x_none_none", (2, 1, 1), P('x', None, None)),
)
def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
  """HLO sharding survives a round trip through PositionalSharding unchanged."""
  ndim = len(mesh_shape)
  axis_names = ('x', 'y') if ndim == 2 else ('x', 'y', 'z')
  mesh = jtu.create_global_mesh(mesh_shape, axis_names)
  named = jax.sharding.NamedSharding(mesh, pspec)
  # The array rank used for lowering is taken equal to the mesh rank.
  before = named._to_xla_hlo_sharding(ndim)
  positional = _op_sharding_to_pos_sharding(before, named._device_assignment)
  after = positional._to_xla_hlo_sharding(ndim)
  self.assertTrue(op_shardings.are_op_shardings_equal(before, after))
@parameterized.named_parameters(
("2d_mesh_x", (1, 1), P("x", "y")),
("2d_mesh_x_y", (4, 2), P("x", "y")),
("2d_mesh_empty", (2, 1), P()),
("2d_mesh_p_none", (2, 1), P(None)),
("2d_mesh_none_none", (2, 1), P(None, None)),
("2d_mesh_tuple_empty", (2, 1), P((),)),
("2d_mesh_x_none", (2, 1), P(('x',), None)),
("2d_mesh_xy_none", (2, 1), P(('x', 'y'), None)),
("2d_mesh_none", (2, 1), None),
("2d_mesh_x_tuple_empty", (2, 1), P('x', (), (), ())),
("2d_mesh_3_tuple_empty", (2, 1), P((), (), ())),
("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
("3d_mesh2_x_y_none", (1, 1, 4), P('x', 'y', None)),
("3d_mesh2_xy_none", (1, 1, 4), P(('x', 'y'), None)),
)
def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
if len(mesh_shape) == 2:
axis_names = ('x', 'y')
elif len(mesh_shape) == 3:
axis_names = ('x', 'y', 'z')
else:
axis_names = ('x',)
mesh = jtu.create_global_mesh(mesh_shape, axis_names)
mps = jax.sharding.NamedSharding(mesh, pspec)
shape = (8, 2, 4)
mps_op_sharding = mps._to_xla_hlo_sharding(len(shape))