672a013 · Aug 28, 2024


JAX debugging flags

JAX offers flags and context managers that make errors easier to catch.

jax_debug_nans configuration option and context manager

Summary: Enable the jax_debug_nans flag to automatically detect when NaNs are produced in jax.jit-compiled code (but not in jax.pmap or jax.pjit-compiled code).

jax_debug_nans is a JAX flag that, when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled functions: when a NaN output is detected from a JIT-compiled function, the function is re-run eagerly (i.e. without compilation) and raises an error at the specific primitive that produced the NaN.

Usage

If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:

  • setting the JAX_DEBUG_NANS=True environment variable;
  • adding jax.config.update("jax_debug_nans", True) near the top of your main file;
  • adding jax.config.parse_flags_with_absl() to your main file, then setting the option with a command-line flag like --jax_debug_nans=True.

Example

import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
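
The section title also mentions a context manager. As a minimal sketch, assuming the jax.debug_nans context manager and that the flag has not been enabled globally elsewhere, the check can be scoped to a single block instead of the whole program (the names silent and caught are illustrative only):

```python
import jax
import jax.numpy as jnp

def f(x, y):
  return x / y

# Outside the context manager the check is off (assuming the flag was not
# enabled elsewhere), so the NaN is returned silently.
silent = jax.jit(f)(0., 0.)

caught = None
# Inside the context manager the NaN check is active for this block only.
with jax.debug_nans(True):
  try:
    jax.jit(f)(0., 0.)
  except FloatingPointError as err:
    caught = err  # keep the exception for inspection

print("silent result is NaN:", bool(jnp.isnan(silent)))
print("caught inside context:", type(caught).__name__)
```

Scoping the check this way keeps the cost of eager re-runs limited to the code under investigation.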

Strengths and limitations of jax_debug_nans

Strengths
  • Easy to apply
  • Precisely detects where NaNs were produced
  • Throws a standard Python exception and is compatible with PDB postmortem
Limitations
  • Not compatible with jax.pmap or jax.pjit
  • Re-running functions eagerly can be slow
  • Raises errors on false positives (e.g. intentionally created NaNs)
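
One possible way to work around the false-positive limitation is to switch the check off locally around code that creates NaNs on purpose. The sketch below assumes that the jax.debug_nans context manager overrides the global setting within its block; masked_fill is a hypothetical helper, not part of JAX:

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_debug_nans", True)  # NaN check enabled globally

def masked_fill(x):
  # Hypothetical helper: intentionally marks negative entries as NaN.
  return jnp.where(x < 0, jnp.nan, x)

# Temporarily disable the check so the intentional NaN does not raise.
with jax.debug_nans(False):
  out = masked_fill(jnp.array([1.0, -1.0]))

# Inspect the result with NumPy on the host, so that no further JAX
# operation produces a NaN while the global check is still enabled.
has_nan = bool(np.isnan(np.asarray(out)).any())
print("contains intentional NaN:", has_nan)
```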

jax_disable_jit configuration option and context manager

Summary: Enable the jax_disable_jit flag to disable JIT-compilation, which lets you use traditional Python debugging tools like print and pdb.

jax_disable_jit is a JAX flag that, when enabled, disables JIT-compilation throughout JAX (including in control flow functions like jax.lax.cond and jax.lax.scan).

Usage

You can disable JIT-compilation by:

  • setting the JAX_DISABLE_JIT=True environment variable;
  • adding jax.config.update("jax_disable_jit", True) near the top of your main file;
  • adding jax.config.parse_flags_with_absl() to your main file, then setting the option with a command-line flag like --jax_disable_jit=True.

Examples

import jax
import jax.numpy as jnp

jax.config.update("jax_disable_jit", True)

def f(x):
  y = jnp.log(x)
  # With JIT disabled, y is a concrete value, so Python control flow works.
  if jnp.isnan(y):
    breakpoint()
  return y

jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!
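
The same effect can be limited to one block with a context manager. The following is a minimal sketch, assuming the jax.disable_jit context manager; safe_log is a hypothetical function whose Python-level branch on a computed value only works when the code runs eagerly:

```python
import jax
import jax.numpy as jnp

@jax.jit
def safe_log(x):
  y = jnp.log(x)
  # Python branch on a computed value: fails while tracing under jit,
  # but works when the function runs eagerly.
  if jnp.isnan(y):
    return jnp.zeros_like(y)
  return y

# Inside the context manager, the jit decorator is bypassed.
with jax.disable_jit():
  eager_out = safe_log(-2.0)

# Outside the context manager, the same call is traced and the branch fails.
trace_error = None
try:
  safe_log(-2.0)
except jax.errors.ConcretizationTypeError as err:
  trace_error = err

print("eager result:", float(eager_out))
print("error under jit:", type(trace_error).__name__)
```

Here the NaN branch returns zero only to keep the sketch non-interactive; in a debugging session the branch could call breakpoint() as in the example above.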

Strengths and limitations of jax_disable_jit

Strengths
  • Easy to apply
  • Enables use of Python's built-in breakpoint and print
  • Throws standard Python exceptions and is compatible with PDB postmortem
Limitations
  • Not compatible with jax.pmap or jax.pjit
  • Running functions without JIT-compilation can be slow