forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pickle_test.py
200 lines (157 loc) · 6.22 KB
/
pickle_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for interoperability between JAX and pickling libraries."""
import pickle
import unittest
from absl.testing import absltest
from absl.testing import parameterized
try:
import cloudpickle
except ImportError:
cloudpickle = None
import jax
from jax import numpy as jnp
from jax import config
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
import numpy as np
config.parse_flags_with_absl()
def _get_device_by_id(device_id: int) -> xc.Device:
for device in jax.devices():
if device.id == device_id:
return device
raise ValueError(f'Device {device_id} was not found')
xc.Device.__reduce__ = lambda d: (_get_device_by_id, (d.id,))
if cloudpickle is not None:
def _reduce_mesh(mesh):
# Avoid including mesh._hash in the serialized bytes for Mesh. Without this
# the Mesh would be different among the workers.
return jax.sharding.Mesh, (mesh.devices, mesh.axis_names)
cloudpickle.CloudPickler.dispatch_table[jax.sharding.Mesh] = _reduce_mesh
class CloudpickleTest(jtu.JaxTestCase):
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
def testPickleOfJittedFunctions(self):
@jax.jit
def f(x, y):
return x * y
@jax.jit
def g(z):
return f(z, z + 77) # noqa: F821
expected = g(32)
s = cloudpickle.dumps(g)
del f, g
g_unpickled = pickle.loads(s)
actual = g_unpickled(32)
self.assertEqual(expected, actual)
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
def testPickleOfPmappedFunctions(self):
@jax.pmap
def f(x, y):
return x * y
@jax.pmap
def g(z):
return f(z, z + 77) # noqa: F821
expected = g(jnp.asarray([[32]]))
s = cloudpickle.dumps(g)
del f, g
g_unpickled = pickle.loads(s)
actual = g_unpickled(jnp.asarray([[32]]))
self.assertEqual(expected, actual)
class PickleTest(jtu.JaxTestCase):
def testPickleOfDeviceArray(self):
x = jnp.arange(10.0)
s = pickle.dumps(x)
y = pickle.loads(s)
self.assertArraysEqual(x, y)
self.assertIsInstance(y, type(x))
self.assertEqual(x.aval, y.aval)
def testPickleOfDeviceArrayWeakType(self):
x = jnp.array(4.0)
self.assertEqual(x.aval.weak_type, True)
s = pickle.dumps(x)
y = pickle.loads(s)
self.assertArraysEqual(x, y)
self.assertIsInstance(y, type(x))
self.assertEqual(x.aval, y.aval)
@jtu.sample_product(prng_name=['threefry2x32', 'rbg', 'unsafe_rbg'])
def testPickleOfKeyArray(self, prng_name):
with jax.default_prng_impl(prng_name):
k1 = jax.random.PRNGKey(72)
s = pickle.dumps(k1)
k2 = pickle.loads(s)
self.assertEqual(k1.dtype, k2.dtype)
self.assertArraysEqual(jax.random.key_data(k1),
jax.random.key_data(k2))
@parameterized.parameters(
(jax.sharding.PartitionSpec(),),
(jax.sharding.PartitionSpec(None),),
(jax.sharding.PartitionSpec('x', None),),
(jax.sharding.PartitionSpec(None, 'y'),),
(jax.sharding.PartitionSpec('x', 'y'),),
(jax.sharding.PartitionSpec(('x', 'y'),),),
)
def testPickleOfPartitionSpecs(self, partition_spec):
restored_partition_spec = pickle.loads(pickle.dumps(partition_spec))
self.assertIsInstance(restored_partition_spec, jax.sharding.PartitionSpec)
self.assertTupleEqual(partition_spec, restored_partition_spec)
def testPickleX64(self):
with jax.experimental.enable_x64():
x = jnp.array(4.0, dtype='float64')
s = pickle.dumps(x)
with jax.experimental.disable_x64():
y = pickle.loads(s)
self.assertEqual(x.dtype, jnp.float64)
self.assertArraysEqual(x, y, check_dtypes=False)
self.assertEqual(y.dtype, jnp.float32)
self.assertEqual(y.aval.dtype, jnp.float32)
self.assertIsInstance(y, type(x))
def testPickleTracerError(self):
with self.assertRaises(jax.errors.ConcretizationTypeError):
jax.jit(pickle.dumps)(0)
def testPickleSharding(self):
sharding = pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked(
(2, 2)), pxla.Unstacked(3)), (pxla.ShardedAxis(0), pxla.ShardedAxis(1),
pxla.ShardedAxis(2), pxla.Replicated(4)))
self.assertEqual(pickle.loads(pickle.dumps(sharding)), sharding)
def testPickleOpSharding(self):
sharding = pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
(pxla.ShardedAxis(0), pxla.ShardedAxis(1)))
op_sharding = sharding.sharding_proto()
self.assertTrue(
xc.HloSharding.from_proto(pickle.loads(pickle.dumps(op_sharding))),
xc.HloSharding.from_proto(op_sharding))
def test_pickle_single_device_sharding(self):
s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
def test_pickle_pmap_sharding(self):
ss = pxla.ShardingSpec(
sharding=(pxla.Unstacked(8),),
mesh_mapping=(pxla.ShardedAxis(0),))
s = jax.sharding.PmapSharding(jax.devices(), ss)
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
def test_pickle_gspmd_sharding(self):
op_sharding = xla.xc.OpSharding()
op_sharding.type = xla.xc.OpSharding.Type.REPLICATED
s = jax.sharding.GSPMDSharding(jax.devices(), op_sharding)
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
def test_pickle_named_sharding(self):
s = jax.sharding.NamedSharding(
mesh=jax.sharding.Mesh(np.array(jax.devices()), 'd'),
spec=jax.sharding.PartitionSpec('d'))
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())