forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
jax_jit_test.py
223 lines (184 loc) · 8.41 KB
/
jax_jit_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import inspect
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import api
from jax._src import core
from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax.config import config
import numpy as np
# Parse JAX config flags through absl before any test runs.
# NOTE(review): purpose inferred from the method name (presumably lets JAX
# config options be set from the test command line) — confirm in jax.config.
config.parse_flags_with_absl()
def _cpp_device_put(value, device):
  """Places `value` on `device` via `jaxlib.jax_jit.device_put`.

  This calls the jaxlib `jax_jit.device_put` entry point directly instead of
  `jax.device_put`, forwarding the current `config.x64_enabled` setting as its
  second positional argument.

  Args:
    value: the object to transfer.
    device: the target device.

  Returns:
    Whatever `jaxlib.jax_jit.device_put` returns.
  """
  x64_enabled = config.x64_enabled
  return jaxlib.jax_jit.device_put(value, x64_enabled, device)
class JaxJitTest(jtu.JaxTestCase):
  """Tests for the `jax_jit` helpers reached through `jax._src.lib`.

  Several tests are parameterized over `jax.device_put` and `_cpp_device_put`
  (which calls `jaxlib.jax_jit.device_put` directly), so both paths are held
  to the same expectations. Loops over dtypes/values run each case inside
  `subTest`, so a failure reports which case broke instead of aborting at the
  first one.
  """

  def test_is_float_0(self):
    """`_is_float0` accepts a float0 array and rejects a default-dtype one."""
    self.assertTrue(
        jaxlib.jax_jit._is_float0(np.zeros((5, 5), dtype=jax.float0)))
    self.assertFalse(jaxlib.jax_jit._is_float0(np.zeros((5, 5))))

  @parameterized.parameters([jax.device_put, _cpp_device_put])
  def test_device_put_on_numpy_masked_array(self, device_put_function):
    """device_put raises ValueError on numpy masked arrays.

    The C++ path is skipped until it also rejects masked arrays.
    """
    # TODO(jakevdp): add appropriate logic to jaxlib device_put and update this test.
    if device_put_function is _cpp_device_put:
      self.skipTest("cpp device_put does not yet reject masked arrays.")
    device = jax.devices()[0]
    value = np.ma.array([1, 2, 3], mask=[True, False, True])
    with self.assertRaisesRegex(ValueError,
                                "numpy masked arrays are not supported"):
      device_put_function(value, device=device)

  @parameterized.parameters([jax.device_put, _cpp_device_put])
  def test_device_put_on_numpy_scalars(self, device_put_function):
    """Numpy scalars become strongly typed 0-d values of the canonical dtype."""
    device = jax.devices()[0]
    for dtype in jtu.supported_dtypes():
      with self.subTest(dtype=dtype):
        output_buffer = device_put_function(dtype(0), device=device)
        self.assertFalse(output_buffer.aval.weak_type)
        # Keep the canonicalized dtype under its own name instead of
        # rebinding the loop variable.
        expected_dtype = dtypes.canonicalize_dtype(dtype)
        self.assertEqual(output_buffer.aval,
                         core.ShapedArray((), expected_dtype))
        self.assertEqual(output_buffer.dtype, expected_dtype)

  @parameterized.parameters([jax.device_put, _cpp_device_put])
  def test_device_put_on_numpy_arrays(self, device_put_function):
    """Numpy arrays keep shape/contents and get the canonical, strong dtype."""
    device = jax.devices()[0]
    for dtype in jtu.supported_dtypes():
      with self.subTest(dtype=dtype):
        value = np.zeros((3, 4), dtype=dtype)
        output_buffer = device_put_function(value, device=device)
        self.assertFalse(output_buffer.aval.weak_type)
        # Keep the canonicalized dtype under its own name instead of
        # rebinding the loop variable.
        expected_dtype = dtypes.canonicalize_dtype(dtype)
        self.assertEqual(output_buffer.aval,
                         core.ShapedArray((3, 4), expected_dtype))
        self.assertEqual(output_buffer.dtype, expected_dtype)
        np.testing.assert_array_equal(
            output_buffer, np.zeros((3, 4), dtype=expected_dtype))

  @parameterized.parameters([jax.device_put, _cpp_device_put])
  def test_device_put_on_buffers(self, device_put_function):
    """device_put of a jit output preserves its dtype, aval and contents."""
    device = jax.devices()[0]
    jitted_f = jax.jit(lambda x: x + 1)
    # We run it twice, to cover `_DeviceArray` and the C++ `Buffer`.
    for value in range(2):
      with self.subTest(value=value):
        buffer = jitted_f(value)
        output_buffer = device_put_function(buffer, device=device)
        self.assertEqual(output_buffer.dtype, buffer.dtype)
        self.assertEqual(output_buffer.aval, buffer.aval)
        np.testing.assert_array_equal(output_buffer, np.array(value + 1))

  @parameterized.parameters([jax.device_put, _cpp_device_put])
  def test_device_put_on_sharded_device_array(self, device_put_function):
    """device_put of a pmap output is not a ShardedDeviceArray.

    The result must still match the input's dtype, aval and contents.
    """
    device = jax.devices()[0]
    pmapped_f = jax.pmap(lambda x: x + 1)
    # Run twice, mirroring test_device_put_on_buffers.
    for _ in range(2):
      sda = pmapped_f(np.asarray([[1]]))
      output_buffer = device_put_function(sda, device=device)
      self.assertNotIsInstance(output_buffer,
                               jax.interpreters.pxla.ShardedDeviceArray)
      self.assertEqual(output_buffer.dtype, sda.dtype)
      self.assertEqual(output_buffer.aval, sda.aval)
      np.testing.assert_array_equal(output_buffer, np.asarray(sda))

  def test_device_put_on_python_scalars(self):
    """C++ device_put of Python scalars yields canonical dtypes, like jnp."""
    device = jax.devices()[0]
    cases = [
        (1, dtypes.canonicalize_dtype(np.int64)),
        (1.0, dtypes.canonicalize_dtype(np.float64)),
        (True, np.bool_),
        (False, np.bool_),
    ]
    if not (config.x64_enabled and jtu.device_under_test() == "tpu"):
      # No TPU support for complex128.
      cases.append((1 + 1j, dtypes.canonicalize_dtype(np.complex128)))
    for value, expected_dtype in cases:
      with self.subTest(value=value):
        res = np.asarray(_cpp_device_put(value, device))
        self.assertEqual(res, value)
        self.assertEqual(res.dtype, expected_dtype)
        # Also compare against the Python JAX API, so that a divergence in
        # the dtype chosen by the two paths makes this test fail.
        self.assertEqual(jnp.asarray(value).dtype, res.dtype)

  def test_convert_int_overflow(self):
    """A Python int too large to convert raises RuntimeError in C++."""
    with self.assertRaisesRegex(
        RuntimeError,
        "(Python int too large|Unable to convert Python scalar).*"):
      jaxlib.jax_jit.device_put(int(1e100), True, jax.devices()[0])

  def test_arg_signature_of_value(self):
    """Tests the C++ code-path of `_ArgSignatureOfValue`.

    Signatures must agree with the dtype `jax.device_put` produces. Numpy
    scalars and arrays are strongly typed; Python scalars are weakly typed.
    """
    jax_enable_x64 = config.x64_enabled
    # 1. Numpy scalar types
    for dtype in jtu.supported_dtypes():
      with self.subTest(kind="numpy scalar", dtype=dtype):
        value = dtype(0)
        signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
        self.assertEqual(signature.dtype, jax.device_put(value).dtype)
        self.assertEqual(signature.shape, ())
        self.assertFalse(signature.weak_type)
    # 2. Numpy arrays
    for dtype in jtu.supported_dtypes():
      with self.subTest(kind="numpy array", dtype=dtype):
        value = np.zeros((3, 4), dtype=dtype)
        signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
        self.assertEqual(signature.dtype, jax.device_put(value).dtype)
        self.assertEqual(signature.shape, (3, 4))
        self.assertFalse(signature.weak_type)
    # 3. Python scalar types
    python_scalar_cases = [
        (1, dtypes.canonicalize_dtype(np.int64)),
        (1.0, dtypes.canonicalize_dtype(np.float64)),
        (True, np.bool_),
        (False, np.bool_),
    ]
    if not (jax_enable_x64 and jtu.device_under_test() == "tpu"):
      # No TPU support for complex128.
      python_scalar_cases.append(
          (1 + 1j, dtypes.canonicalize_dtype(np.complex128)))
    for value, expected_dtype in python_scalar_cases:
      with self.subTest(kind="python scalar", value=value):
        signature = jaxlib.jax_jit._ArgSignatureOfValue(value, jax_enable_x64)
        self.assertEqual(signature.dtype, jax.device_put(value).dtype)
        self.assertEqual(signature.dtype, expected_dtype)
        self.assertEqual(signature.shape, ())
        self.assertTrue(signature.weak_type)

  def test_signature_support(self):
    """jit preserves the wrapped function's inspect.signature."""
    # Use whichever jit implementation the pjit/jit API-merge flag selects.
    if jax.config.jax_jit_pjit_api_merge:
      jit = jax.jit
    else:
      jit = partial(api._jit, True)

    def f(a, b, c):
      return a + b + c

    jitted_f = jit(f)
    self.assertEqual(inspect.signature(f), inspect.signature(jitted_f))
if __name__ == "__main__":
  # Run this module's tests with JAX's custom test loader.
  test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=test_loader)