Skip to content

Commit 01194bd

Browse files
superbobryjax authors
authored and
jax authors
committed
Clarified the type of the inputs to callback APIs
The callback APIs were migrated to use jax.Arrays for both inputs and outputs in JAX 0.4.27. PiperOrigin-RevId: 634473890
1 parent 380503b commit 01194bd

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

docs/_tutorials/external-callbacks.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def f(x):
5656
result = f(2)
5757
```
5858

59-
This works by passing the runtime value represented by `y` back to the host process, where the host can print the value.
59+
This works by passing the runtime value of `y` as a CPU {class}`jax.Array` back to the host process, where the host can print it.
6060

6161
(external-callbacks-flavors-of-callback)=
6262
## Flavors of callback

jax/_src/callback.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,14 @@ def pure_callback_impl(
7272
vectorized: bool,
7373
):
7474
del sharding, vectorized, result_avals
75-
cpu_device, *_ = jax.local_devices(backend="cpu")
75+
try:
76+
cpu_device, *_ = jax.local_devices(backend="cpu")
77+
except RuntimeError as e:
78+
raise RuntimeError(
79+
"jax.pure_callback failed to find a local CPU device to place the"
80+
" inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
81+
" JAX_PLATFORMS environment variable."
82+
) from e
7683
args = jax.device_put(args, cpu_device)
7784
with jax.default_device(cpu_device):
7885
try:
@@ -262,9 +269,8 @@ def pure_callback(
262269
For more explanation, see `External Callbacks`_.
263270
264271
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
265-
The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
266-
should also return NumPy arrays. Execution takes place on CPU, like any
267-
Python+NumPy function.
272+
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
273+
it should also return JAX arrays on CPU.
268274
269275
The callback is treated as functionally pure, meaning it has no side-effects
270276
and its output value depends only on its argument values. As a consequence, it
@@ -357,7 +363,14 @@ def io_callback_impl(
357363
ordered: bool,
358364
):
359365
del result_avals, sharding, ordered
360-
cpu_device, *_ = jax.local_devices(backend="cpu")
366+
try:
367+
cpu_device, *_ = jax.local_devices(backend="cpu")
368+
except RuntimeError as e:
369+
raise RuntimeError(
370+
"jax.io_callback failed to find a local CPU device to place the"
371+
" inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
372+
" JAX_PLATFORMS environment variable."
373+
) from e
361374
args = jax.device_put(args, cpu_device)
362375
with jax.default_device(cpu_device):
363376
try:

jax/_src/debugging.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,20 @@ class OrderedDebugEffect(effects.Effect):
7777
def debug_callback_impl(*args, callback: Callable[..., Any],
7878
effect: DebugEffect):
7979
del effect
80-
cpu_device, *_ = jax.local_devices(backend="cpu")
80+
try:
81+
cpu_device, *_ = jax.local_devices(backend="cpu")
82+
except RuntimeError as e:
83+
raise RuntimeError(
84+
"jax.debug.callback failed to find a local CPU device to place the"
85+
" inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
86+
" JAX_PLATFORMS environment variable."
87+
) from e
8188
args = jax.device_put(args, cpu_device)
8289
with jax.default_device(cpu_device):
8390
try:
8491
callback(*args)
8592
except BaseException:
86-
logger.exception("jax.debug_callback failed")
93+
logger.exception("jax.debug.callback failed")
8794
raise
8895
return ()
8996

0 commit comments

Comments
 (0)