forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
global_device_array_test.py
383 lines (340 loc) · 15.4 KB
/
global_device_array_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""
from absl.testing import absltest
from absl.testing import parameterized
import unittest
import numpy as np
import jax
from jax import core
from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import Mesh
import jax.experimental.global_device_array as gda_lib
from jax.experimental.global_device_array import GlobalDeviceArray, get_shard_indices
from jax.config import config
# Hand command-line flags to JAX's `config` (imported from jax.config above)
# via absl, presumably so JAX options can be set when invoking the tests.
config.parse_flags_with_absl()
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
  """Builds a GlobalDeviceArray whose shards are slices of a host array.

  Args:
    global_shape: shape of the full (global) array.
    global_mesh: device mesh the array is laid out on.
    mesh_axes: partition spec mapping array dimensions to mesh axes.
    global_data: optional host array to slice shards from. When omitted, an
      ``np.arange`` array reshaped to ``global_shape`` is used.

  Returns:
    A ``(gda, global_data)`` tuple, so callers can compare the GDA against the
    host data it was built from.
  """
  data = (np.arange(prod(global_shape)).reshape(global_shape)
          if global_data is None else global_data)

  def _slice_for(index):
    # Each shard is simply the host array sliced at the shard's index.
    return data[index]

  gda = GlobalDeviceArray.from_callback(global_shape, global_mesh, mesh_axes,
                                        _slice_for)
  return gda, data
class GDATest(jtu.JaxTestCase):
  """Tests for GlobalDeviceArray construction, sharding and accessors."""

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but, for convenience, only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", P("x"),
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", P("y"),
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", P(None, "y"),
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", P(("x", "y")),
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", P(),
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    """Checks the shard layout of an (8, 2) array on a 4x2 ('x', 'y') mesh.

    Also checks that global_shards agree with local_shards.
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 2)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.mesh_axes, mesh_axes)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in gda.local_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
    for s in gda.local_shards:
      self.assertEqual(s.data.aval,
                       core.ShapedArray(expected_shard_shape, s.data.dtype))
    # Every global shard must mirror the corresponding local shard.
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)

  @parameterized.named_parameters(
      ("mesh_x_y_z", P("x", "y", "z"),
       # There are more slices but, for convenience, only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 4), slice(0, 2), slice(0, 1)),
        (slice(0, 4), slice(0, 2), slice(1, 2))),
       (4, 2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_xy_z", P(("x", "y"), "z"),
       ((slice(0, 2), slice(0, 2), slice(None)),
        (slice(0, 2), slice(2, 4), slice(None))),
       (2, 2, 2),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_z", P("z"),
       ((slice(0, 4), slice(None), slice(None)),
        (slice(4, 8), slice(None), slice(None))),
       (4, 4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3]),
  )
  def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    """Checks the shard layout of an (8, 4, 2) array on a 2x2x2 mesh."""
    global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
    global_input_shape = (8, 4, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 3)
    self.assertEqual(gda.size, 64)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  @parameterized.named_parameters(
      ("mesh_x", P("x"),
       # There are more slices but, for convenience, only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 2),), (slice(2, 4),)),
       (2,),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_none", P(),
       ((slice(None),), (slice(None),)),
       (16,),
       [0, 1, 2, 3, 4, 5, 6, 7]),
  )
  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids):
    """Checks the shard layout of a (16,) array on an 8-device 1D mesh."""
    # `('x',)` is a one-element tuple; the previous `('x')` was just a string.
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (16,)
    global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  def test_gda_shape_0_1d_mesh(self):
    """A zero-size (0,) array gives every device an empty replicated shard."""
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (0,)
    mesh_axes = P(None)
    def cb(index):
      return np.array([])
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 0)
    for i, s in enumerate(gda.local_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(s.data.to_py(), np.array([]))
    self.assertEqual(gda.dtype, np.float32)
    self.assertEqual(
        gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
        (0,))

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but, for convenience, only the first 2 are
       # checked. The indices + shard_shape + replica_id should be unique
       # enough.
       ((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
       (4, 1),
       [0, 0, 0, 0]),
  )
  def test_gda_subset_devices(self, mesh_axes, expected_index,
                              expected_shard_shape, expected_replica_ids):
    """Checks sharding on a 2x2 mesh (a device subset, per the test name)."""
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertArraysEqual(g.data, l.data)

  def test_gda_batched_callback(self):
    """from_batched_callback passes all local shard indices in one call."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(indices):
      # One index per local device, delivered as a single batch.
      self.assertEqual(len(indices), len(global_mesh.local_devices))
      return [global_input_data[index] for index in indices]
    gda = GlobalDeviceArray.from_batched_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1]])
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[2, 3]])
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_batched_callback_with_devices(self):
    """from_batched_callback_with_devices gets (index, devices) pairs.

    The callback here receives 4 pairs, each listing 2 devices, and must
    return one device buffer per listed device.
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x')
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
    def cb(cb_inp):
      self.assertLen(cb_inp, 4)
      dbs = []
      for inp in cb_inp:
        index, devices = inp
        self.assertLen(devices, 2)
        array = global_input_data[index]
        dbs.extend([jax.device_put(array, device) for device in devices])
      return dbs
    gda = GlobalDeviceArray.from_batched_callback_with_devices(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_str_repr(self):
    """Checks the exact str() and repr() formats of a GlobalDeviceArray."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    # np.arange defaults to the platform integer (usually int64). Pin int32
    # so the expected `dtype=int32` below does not depend on whether JAX's
    # 64-bit mode (jax_enable_x64) is enabled.
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.int32).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertEqual(str(gda),
                     'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
    self.assertEqual(
        repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                    "global_mesh_shape={'x': 4, 'y': 2}, "
                    "mesh_axes=PartitionSpec(('x', 'y'),))"))

  def test_gda_equality_raises_not_implemented(self):
    """Comparing two GDAs with == raises NotImplementedError."""
    global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(None,)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    same_input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    with self.assertRaisesRegex(NotImplementedError,
        'GlobalDeviceArray equality is intentionally unimplemented.'):
      # Assigned to `_` so linters don't flag a pointless expression.
      _ = input_gda == same_input_gda

  def test_mesh_hash(self):
    """Meshes with equal shape and axis names hash equal; others differ."""
    global_mesh1 = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_mesh2 = jtu.create_global_mesh((2, 4), ('x', 'y'))
    global_mesh3 = jtu.create_global_mesh((4, 2), ('x', 'y'))
    self.assertNotEqual(hash(global_mesh1), hash(global_mesh2))
    self.assertEqual(hash(global_mesh1), hash(global_mesh3))

  def test_device_mismatch(self):
    """Buffers whose device order differs from the mesh's raise ValueError."""
    devices = jax.devices()
    if len(devices) < 8:
      raise unittest.SkipTest("Test requires 8 global devices.")
    # Deliberately shuffled relative to jax.local_devices() order.
    mesh_devices = np.array([[devices[0], devices[2]],
                             [devices[3], devices[1]],
                             [devices[4], devices[6]],
                             [devices[7], devices[5]]])
    global_mesh = Mesh(mesh_devices, ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
    dbs = [
        jax.device_put(global_input_data[indices[d]], d)
        for d in jax.local_devices()
    ]
    with self.assertRaisesRegex(
        ValueError,
        'The `global_mesh.local_devices` and `device_buffers` device order'):
      GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)

  def test_gda_block_until_ready(self):
    """block_until_ready() returns the same GDA object."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertIs(gda.block_until_ready(), gda)

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y")),
      ("mesh_x", P("x")),
      ("mesh_y", P("y")),
      ("mesh_none_y", P(None, "y")),
      ("mesh_xy", P(("x", "y"))),
      ("mesh_fully_replicated", P()),
  )
  def test_gda_value(self, mesh_axes):
    """`_value` reassembles the full host array for every partitioning."""
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    input_shape = (8, 2)
    gda, global_data = create_gda(input_shape, global_mesh, mesh_axes)
    self.assertArraysEqual(gda._value, global_data)
if __name__ == '__main__':
  # Run the tests through absl, using JAX's test loader (jtu.JaxTestLoader).
  absltest.main(testLoader=jtu.JaxTestLoader())