Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page summarizes each tool; follow the "Read more" links at the bottom for the complete guides.
Table of contents:
- Interactive inspection with jax.debug
- Functional error checks with jax.experimental.checkify
- Throwing Python errors with JAX’s debug flags
Complete guide: print_breakpoint (listed under "Read more" at the bottom of this page)
Summary: Use jax.debug.print to print values to stdout in jax.jit-, jax.pmap-, and pjit-decorated functions, and jax.debug.breakpoint to pause execution of your compiled function to inspect values in the call stack:
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    jax.debug.print("🤯 {x} 🤯", x=x)
    y = jnp.sin(x)
    jax.debug.breakpoint()
    jax.debug.print("🤯 {y} 🤯", y=y)
    return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
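To see why jax.debug.print is needed at all, it can help to put it next to Python's built-in print inside the same jitted function. The sketch below is illustrative only: the function name normalize and its inputs are made up for this example. The built-in print runs while JAX traces the function, so it shows an abstract tracer rather than a number, while jax.debug.print reports the value computed at run time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # `normalize` is a hypothetical example function, not part of JAX.
    norm = jnp.linalg.norm(x)
    # Built-in print runs at trace time and shows a tracer, not a number.
    print("trace time:", norm)
    # jax.debug.print runs with the compiled function and shows the value.
    jax.debug.print("run time: norm = {n}", n=norm)
    return x / norm

out = normalize(jnp.array([3.0, 4.0]))
```

Calling normalize again with an input of the same shape and dtype should reuse the compiled function, so only the jax.debug.print line would appear on later calls.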
Complete guide: checkify_guide (listed under "Read more" at the bottom of this page)
Summary: Checkify lets you add jit-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the checkify.checkify transformation together with the assert-like checkify.check function to add runtime checks to JAX code:
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
    checkify.check(i >= 0, "index needs to be non-negative!")
    y = x[i]
    z = jnp.sin(y)
    return z

jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
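The checkified function returns the error as a value instead of raising it, so the caller decides what to do with it. Here is a minimal sketch of handling both outcomes; safe_log and its message are hypothetical names chosen for illustration. When no check fails, err.get() returns None.

```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def safe_log(x):
    # `safe_log` is a hypothetical example function, not part of JAX.
    checkify.check(jnp.all(x > 0), "safe_log needs strictly positive inputs")
    return jnp.log(x)

checked_log = jax.jit(checkify.checkify(safe_log))

# All inputs pass the check: no error is recorded.
err, y = checked_log(jnp.array([1.0, 2.0]))
ok_message = err.get()

# One input violates the check: the error carries the message.
err, y = checked_log(jnp.array([1.0, -2.0]))
bad_message = err.get()
print(ok_message, "|", bad_message)
```

Use err.get() when you want to log or branch on the message, and err.throw() when a failed check should stop the program.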
You can also use checkify to automatically add common checks:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
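The error sets can also be enabled one at a time. As a rough sketch, the example below turns on only checkify.float_checks for a small division function (ratio is a made-up name for this example) and runs it under jax.jit, with no checkify.check calls written by hand.

```python
from jax.experimental import checkify
import jax

def ratio(x, y):
    # `ratio` is a hypothetical example function, not part of JAX.
    return x / y

# Only the automatic floating-point checks are enabled here.
checked_ratio = jax.jit(checkify.checkify(ratio, errors=checkify.float_checks))

err, out = checked_ratio(1.0, 2.0)
clean_message = err.get()    # no floating-point problem in this call

err, bad_out = checked_ratio(0.0, 0.0)
failure_message = err.get()  # describes the problem found in 0 / 0
print(clean_message, "|", failure_message)
```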
Complete guide: flags (listed under "Read more" at the bottom of this page)
Summary: Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code (but not in jax.pmap- or pjit-compiled code), and enable the jax_disable_jit flag to disable JIT compilation so that you can use traditional Python debugging tools like print and pdb.
import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
    return x / y

jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
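The effect of disabling JIT can be tried on a single call with the jax.disable_jit() context manager, which is used here in place of setting the jax_disable_jit flag globally. The sketch below is a small illustration with a made-up function, scaled_sin. The float() conversion needs a concrete value, so it would fail on a tracer under normal jit tracing, but it works once JIT is disabled.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sin(x):
    # `scaled_sin` is a hypothetical example function, not part of JAX.
    y = jnp.sin(x)
    # float() requires a concrete value, which is available with JIT disabled.
    print("y =", float(y))
    return 2 * y

# Inside this block the jit decorator is ignored and the body runs eagerly.
with jax.disable_jit():
    out = scaled_sin(1.0)
```

The same eager execution is what makes pdb usable: a breakpoint placed inside the function body would see ordinary array values.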
Read more:
- print_breakpoint
- checkify_guide
- flags