forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxla_bridge_test.py
240 lines (191 loc) · 8 KB
/
xla_bridge_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import warnings
from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax.interpreters import xla
from jax._src.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
mock = absltest.mock
class XlaBridgeTest(jtu.JaxTestCase):
def test_set_device_assignment_no_partition(self):
compile_options = xb.get_compile_options(
num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
"0 1 2 3 \n")
self.assertEqual(compile_options.device_assignment.__repr__(),
expected_device_assignment)
def test_set_device_assignment_with_partition(self):
compile_options = xb.get_compile_options(
num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
"0 2 \nComputation 1: 1 3 \n")
self.assertEqual(compile_options.device_assignment.__repr__(),
expected_device_assignment)
def test_parameter_replication_default(self):
c = xc.XlaBuilder("test")
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
built_c = c.Build()
assert "replication" not in built_c.as_hlo_text()
def test_parameter_replication(self):
c = xc.XlaBuilder("test")
_ = xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
False)
built_c = c.Build()
assert "parameter_replication={false}" in built_c.as_hlo_text()
def test_local_devices(self):
self.assertNotEmpty(xb.local_devices())
with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
xb.local_devices(100)
with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
xb.local_devices(backend="foo")
def test_timer_tpu_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
def _mock_tpu_client():
time_to_wait = 5
start = time.time()
while not w:
if time.time() - start > time_to_wait:
raise ValueError(
"This test should not hang for more than "
f"{time_to_wait} seconds.")
time.sleep(0.1)
self.assertLen(w, 1)
msg = str(w[-1].message)
self.assertIn("Did you run your code on all TPU hosts?", msg)
with mock.patch("jax._src.lib.xla_client.make_tpu_client",
side_effect=_mock_tpu_client):
xb.tpu_client_timer_callback(0.01)
class GetBackendTest(jtu.JaxTestCase):
class _DummyBackend:
def __init__(self, platform, device_count):
self.platform = platform
self._device_count = device_count
def device_count(self):
return self._device_count
def process_index(self):
return 0
def local_devices(self):
return []
def _register_factory(self, platform: str, priority, device_count=1):
xb.register_backend_factory(
platform, lambda: self._DummyBackend(platform, device_count),
priority=priority)
def setUp(self):
self._orig_factories = xb._backend_factories
xb._backend_factories = {}
self._save_backend_state()
self._reset_backend_state()
# get_backend logic assumes CPU platform is always present.
self._register_factory("cpu", 0)
def tearDown(self):
xb._backend_factories = self._orig_factories
self._restore_backend_state()
def _save_backend_state(self):
self._orig_backends = xb._backends
self._orig_backends_errors = xb._backends_errors
self._orig_default_backend = xb._default_backend
def _reset_backend_state(self):
xb._backends = {}
xb._backends_errors = {}
xb._default_backend = None
xb.get_backend.cache_clear()
def _restore_backend_state(self):
xb._backends = self._orig_backends
xb._backends_errors = self._orig_backends_errors
xb._default_backend = self._orig_default_backend
xb.get_backend.cache_clear()
def test_default(self):
self._register_factory("platform_A", 20)
self._register_factory("platform_B", 10)
backend = xb.get_backend()
self.assertEqual(backend.platform, "platform_A")
# All backends initialized.
self.assertEqual(len(xb._backends), len(xb._backend_factories))
def test_specific_platform(self):
self._register_factory("platform_A", 20)
self._register_factory("platform_B", 10)
backend = xb.get_backend("platform_B")
self.assertEqual(backend.platform, "platform_B")
# All backends initialized.
self.assertEqual(len(xb._backends), len(xb._backend_factories))
def test_unknown_backend_error(self):
with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
xb.get_backend("foo")
def test_backend_init_error(self):
def factory():
raise RuntimeError("I'm not a real backend")
xb.register_backend_factory("error", factory, priority=10)
# No error raised if there's a fallback backend.
default_backend = xb.get_backend()
self.assertEqual(default_backend.platform, "cpu")
with self.assertRaisesRegex(RuntimeError, "I'm not a real backend"):
xb.get_backend("error")
def test_no_devices(self):
self._register_factory("no_devices", -10, device_count=0)
default_backend = xb.get_backend()
self.assertEqual(default_backend.platform, "cpu")
with self.assertRaisesRegex(
RuntimeError,
"Backend 'no_devices' failed to initialize: "
"Backend 'no_devices' provides no devices."):
xb.get_backend("no_devices")
self._reset_backend_state()
self._register_factory("no_devices2", 10, device_count=0)
default_backend = xb.get_backend()
self.assertEqual(default_backend.platform, "cpu")
with self.assertRaisesRegex(
RuntimeError,
"Backend 'no_devices2' failed to initialize: "
"Backend 'no_devices2' provides no devices."):
xb.get_backend("no_devices2")
def test_factory_returns_none(self):
xb.register_backend_factory("none", lambda: None, priority=10)
default_backend = xb.get_backend()
self.assertEqual(default_backend.platform, "cpu")
with self.assertRaisesRegex(
RuntimeError,
"Backend 'none' failed to initialize: "
"Could not initialize backend 'none'"):
xb.get_backend("none")
def cpu_fallback_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
xb.get_backend()
self.assertLen(w, 1)
msg = str(w[-1].message)
self.assertIn("No GPU/TPU found, falling back to CPU", msg)
def test_jax_platforms_flag(self):
self._register_factory("platform_A", 20)
self._register_factory("platform_B", 10)
orig_jax_platforms = config._read("jax_platforms")
try:
config.FLAGS.jax_platforms = "cpu,platform_A"
backend = xb.get_backend()
self.assertEqual(backend.platform, "cpu")
# Only specified backends initialized.
self.assertEqual(len(xb._backends), 2)
backend = xb.get_backend("platform_A")
self.assertEqual(backend.platform, "platform_A")
with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
backend = xb.get_backend("platform_B")
finally:
config.FLAGS.jax_platforms = orig_jax_platforms
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())