@@ -72,7 +72,14 @@ def pure_callback_impl(
72
72
vectorized : bool ,
73
73
):
74
74
del sharding , vectorized , result_avals
75
- cpu_device , * _ = jax .local_devices (backend = "cpu" )
75
+ try :
76
+ cpu_device , * _ = jax .local_devices (backend = "cpu" )
77
+ except RuntimeError as e :
78
+ raise RuntimeError (
79
+ "jax.pure_callback failed to find a local CPU device to place the"
80
+ " inputs on. Make sure \" cpu\" is listed in --jax_platforms or the"
81
+ " JAX_PLATFORMS environment variable."
82
+ ) from e
76
83
args = jax .device_put (args , cpu_device )
77
84
with jax .default_device (cpu_device ):
78
85
try :
@@ -262,9 +269,8 @@ def pure_callback(
262
269
For more explanation, see `External Callbacks`_.
263
270
264
271
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
265
- The input ``callback`` will be passed NumPy arrays in place of JAX arrays and
266
- should also return NumPy arrays. Execution takes place on CPU, like any
267
- Python+NumPy function.
272
+ The input ``callback`` will be passed JAX arrays placed on a local CPU, and
273
+ it should also return JAX arrays on CPU.
268
274
269
275
The callback is treated as functionally pure, meaning it has no side-effects
270
276
and its output value depends only on its argument values. As a consequence, it
@@ -357,7 +363,14 @@ def io_callback_impl(
357
363
ordered : bool ,
358
364
):
359
365
del result_avals , sharding , ordered
360
- cpu_device , * _ = jax .local_devices (backend = "cpu" )
366
+ try :
367
+ cpu_device , * _ = jax .local_devices (backend = "cpu" )
368
+ except RuntimeError as e :
369
+ raise RuntimeError (
370
+ "jax.io_callback failed to find a local CPU device to place the"
371
+ " inputs on. Make sure \" cpu\" is listed in --jax_platforms or the"
372
+ " JAX_PLATFORMS environment variable."
373
+ ) from e
361
374
args = jax .device_put (args , cpu_device )
362
375
with jax .default_device (cpu_device ):
363
376
try :
0 commit comments