forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
transfer_guard_test.py
243 lines (207 loc) · 8.34 KB
/
transfer_guard_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transfer guards."""
import contextlib
import pickle
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
import jax._src.test_util as jtu
import jax.numpy as jnp
from jax.config import config
# Parse command-line flags into the JAX config via absl when this module is
# imported.
config.parse_flags_with_absl()
def _host_to_device_funcs():
  """Generates host-to-device transfer functions.

  Returns:
    A list of ``(function name, is an explicit transfer?, function)`` tuples.
    Each function is a zero-argument callable that performs a host-to-device
    transfer when invoked; none of them is invoked here.
  """
  return [
      # (function name, is an explicit transfer?, function)
      # `jax.device_put` is an explicit transfer request.
      ("host_to_device_jax_device_put", True,
       lambda: jax.device_put(np.ones(10))),
      # Passing a NumPy array into a jitted function transfers it implicitly.
      ("host_to_device_jax_jit", False,
       lambda: jax.jit(lambda x: x)(np.ones(1))),
      # Label matches the "host_to_device_jnp_ones" name used for the same
      # `jnp.ones(1)` call in the hand-written tests of this file.
      ("host_to_device_jnp_ones", False, lambda: jnp.ones(1)),
  ]
def _device_to_device_funcs():
  """Builds callables that move an array from one local device to another.

  Returns an empty list when fewer than two local devices exist, because a
  device-to-device transfer needs a distinct destination. Otherwise returns
  ``(function name, is an explicit transfer?, function)`` tuples whose
  functions target the second local device.
  """
  local = jax.local_devices()
  if len(local) < 2:
    # device-to-device tests require at least 2 devices.
    return []

  # Materializing the source arrays is itself a host-to-device transfer, so
  # it is explicitly allowed here regardless of any surrounding guard.
  with jax.transfer_guard_host_to_device("allow"):
    src_put = jnp.ones(1)
    src_jit = jnp.ones(1)

  target = local[1]
  explicit_put = ("device_to_device_jax_device_put", True,
                  lambda: jax.device_put(src_put, device=target))
  implicit_jit = ("device_to_device_jax_jit", False,
                  lambda: jax.jit(lambda x: x, device=target)(src_jit))
  return [explicit_put, implicit_jit]
def _device_to_host_funcs():
  """Builds callables that read device-resident arrays back to the host.

  On the CPU backend device-to-host incurs no transfer, so nothing is
  returned. Otherwise returns ``(function name, is an explicit transfer?,
  function)`` tuples. Each function reads its own dedicated array; presumably
  this keeps a host copy fetched by one function from hiding the transfer in
  another.
  """
  if jax.default_backend() == "cpu":
    # device-to-host does not incur transfer on the CPU backend.
    return []

  # Creating the arrays is a host-to-device transfer; allow it explicitly.
  with jax.transfer_guard_host_to_device("allow"):
    arrs = [jnp.ones(1) for _ in range(6)]

  explicit = [
      ("device_to_host_jax_device_get", True,
       lambda: jax.device_get(arrs[0])),
  ]
  implicit = [
      ("device_to_host_np_asarray", False, lambda: np.asarray(arrs[1])),
      ("device_to_host_copy_to_host_async", False,
       lambda: arrs[2].copy_to_host_async()),
      ("device_to_host_np_add", False, lambda: np.add(arrs[3], 1)),
      ("device_to_host_str", False, lambda: str(arrs[4])),
      ("device_to_host_pickle_dumps", False,
       lambda: pickle.dumps(arrs[5])),
  ]
  return explicit + implicit
def _all_funcs():
  """Collects every transfer function from all three direction generators.

  The result lists host-to-device, then device-to-device, then
  device-to-host entries, and is a fresh list on every call.
  """
  combined = []
  for generator in (_host_to_device_funcs, _device_to_device_funcs,
                    _device_to_host_funcs):
    combined.extend(generator())
  return combined
# List of test parameters shared by multiple parameterized tests. Each entry
# is passed to `parameterized.named_parameters` as
# (test-name suffix, transfer-function generator, transfer-guard context
# manager). Each generator is paired with the guard for its transfer direction;
# "all" pairs every generator with the global `jax.transfer_guard`.
_COMMON_TEST_PARAMETERS = [
    ("host_to_device", _host_to_device_funcs,
     jax.transfer_guard_host_to_device),
    ("device_to_device", _device_to_device_funcs,
     jax.transfer_guard_device_to_device),
    ("device_to_host", _device_to_host_funcs,
     jax.transfer_guard_device_to_host),
    ("all", _all_funcs, jax.transfer_guard),
]
class TransferGuardTest(jtu.JaxTestCase):
  """Tests for `jax.transfer_guard` and its per-direction variants.

  Transfer functions are run under guard levels ("allow", "log", "disallow",
  "log_explicit", "disallow_explicit") and checked to be allowed or rejected
  accordingly. Explicit transfers (flagged True by the generators) remain
  allowed under "disallow" and are rejected only under "disallow_explicit".
  """

  # `_default_config` is used by `jtu.JaxTestCase` to update the JAX config for
  # every test case. TransferGuardTest disables `--jax_enable_checks` because it
  # can prematurely fetch the value of device arrays and cause device-to-host
  # tests to unexpectedly incur no transfers.
  _default_config = {"jax_enable_checks": False}

  @contextlib.contextmanager
  def assertAllows(self, func_name):
    """Asserts that a transfer in the context is allowed.

    Any exception raised inside the context is re-raised as a RuntimeError
    that names `func_name` and is chained to the original error.
    """
    try:
      yield
    except Exception as e:  # pylint: disable=broad-except
      raise RuntimeError(
          f"Expected a transfer to be allowed while running: {func_name}"
      ) from e

  @contextlib.contextmanager
  def assertLogs(self, func_name):
    """Asserts that a transfer in the context is logged and allowed."""
    # NOTE(review): this presumably overrides `unittest.TestCase.assertLogs`
    # (signature `assertLogs(logger=None, level=None)`) with different
    # semantics; confirm nothing in the base class relies on the original.
    # Only check if the transfer is allowed until Abseil provides an interface
    # to capture logs.
    with self.assertAllows(func_name):
      yield

  @contextlib.contextmanager
  def assertDisallows(self, func_name):
    """Asserts that a transfer in the context is disallowed.

    The context body must raise an Exception. If it does not, the resulting
    assertion failure is re-raised as a RuntimeError naming `func_name`.
    """
    # NOTE(review): any Exception type is accepted, so a body that fails for
    # an unrelated reason also counts as "disallowed".
    try:
      with self.assertRaises(Exception):
        yield
    except Exception as e:  # pylint: disable=broad-except
      raise RuntimeError(
          f"Expected a transfer to be disallowed while running: {func_name}"
      ) from e

  def test_simple(self):
    """Simple transfer guard tests."""
    with jax.transfer_guard("allow"):
      with self.assertAllows("host_to_device_jnp_ones"):
        jnp.ones(1)
    with jax.transfer_guard("log"):
      with self.assertLogs("host_to_device_jnp_ones"):
        jnp.ones(1)
    with jax.transfer_guard("disallow"):
      with self.assertDisallows("host_to_device_jnp_ones"):
        jnp.ones(1)

  def test_nesting(self):
    """An inner guard overrides the outer one only while it is active."""
    with jax.transfer_guard("disallow"):
      with jax.transfer_guard("allow"):
        with self.assertAllows("host_to_device_jnp_ones"):
          jnp.ones(1)
      # Back under the outer "disallow" once the inner context exits.
      with self.assertDisallows("host_to_device_jnp_ones"):
        jnp.ones(1)

  def test_mixed_nesting(self):
    """Global and per-direction guards override each other when nested."""
    with jax.transfer_guard_host_to_device("disallow"):
      with jax.transfer_guard("allow"):
        with self.assertAllows("host_to_device_jnp_ones"):
          jnp.ones(1)
      with self.assertDisallows("host_to_device_jnp_ones"):
        jnp.ones(1)
    with jax.transfer_guard("disallow"):
      with jax.transfer_guard_host_to_device("allow"):
        with self.assertAllows("host_to_device_jnp_ones"):
          jnp.ones(1)
      with self.assertDisallows("host_to_device_jnp_ones"):
        jnp.ones(1)

  @parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
  def test_allow_by_default(self, func_generator, _):
    """Without any guard context, every transfer is allowed."""
    for func_name, _, func in func_generator():
      with self.assertAllows(func_name):
        func()

  @parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
  def test_allow(self, func_generator, jax_transfer_guard):
    """Under "allow", every transfer is allowed."""
    for func_name, _, func in func_generator():
      with jax_transfer_guard("allow"):
        with self.assertAllows(func_name):
          func()

  @parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
  def test_log(self, func_generator, jax_transfer_guard):
    """Under "log", implicit transfers are logged and explicit ones allowed."""
    for func_name, explicit, func in func_generator():
      with jax_transfer_guard("log"):
        if explicit:
          with self.assertAllows(func_name):
            func()
        else:
          # assertLogs currently only checks that the transfer is allowed.
          with self.assertLogs(func_name):
            func()

  @parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
  def test_disallow(self, func_generator, jax_transfer_guard):
    """Under "disallow", implicit transfers fail but explicit ones succeed."""
    for func_name, explicit, func in func_generator():
      with jax_transfer_guard("disallow"):
        if explicit:
          with self.assertAllows(func_name):
            func()
        else:
          with self.assertDisallows(func_name):
            func()

  @parameterized.named_parameters(
      ("device_to_host", _device_to_host_funcs,
       jax.transfer_guard_device_to_host),
      ("all", _device_to_host_funcs, jax.transfer_guard),
  )
  def test_disallow_ignores_arrays_on_cpu(self, func_generator,
                                          jax_transfer_guard):
    """Re-reading an array already fetched to host is allowed under "disallow".

    NOTE(review): `_device_to_host_funcs` returns no functions on the CPU
    backend, so on CPU this loop body never runs.
    """
    for func_name, _, func in func_generator():
      with jax_transfer_guard("allow"):
        # Transfer the device array to host.
        func()
      with jax_transfer_guard("disallow"):
        with self.assertAllows(func_name):
          # No error because the array has a value on host and no new transfer
          # will occur.
          func()

  @parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
  def test_log_explicit(self, func_generator, jax_transfer_guard):
    """Under "log_explicit", every transfer is logged and allowed."""
    for func_name, _, func in func_generator():
      with jax_transfer_guard("log_explicit"):
        with self.assertLogs(func_name):
          func()

  @parameterized.named_parameters(*_COMMON_TEST_PARAMETERS)
  def test_disallow_explicit(self, func_generator, jax_transfer_guard):
    """Under "disallow_explicit", even explicit transfers are rejected."""
    for func_name, _, func in func_generator():
      with jax_transfer_guard("disallow_explicit"):
        with self.assertDisallows(func_name):
          func()
if __name__ == "__main__":
  # When executed as a script, run the tests with JAX's test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())