Skip to content

Commit

Permalink
improve concreteness error message in remat
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Oct 24, 2024
1 parent b8bacda commit 4231128
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
20 changes: 9 additions & 11 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,17 +410,15 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
msg, = e.args
if 'for checkpoint' not in msg:
raise
new_msg = msg + "\n\n" + (
"Consider using the `static_argnums` parameter for `jax.remat` or "
"`jax.checkpoint`. See the `jax.checkpoint` docstring and its example "
"involving `static_argnums`:\n"
"https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html"
"\n")
new_e = core.ConcretizationTypeError.__new__(core.ConcretizationTypeError)
new_e.args = (new_msg,)
raise new_e from None
if 'for checkpoint' in msg:
msg += "\n\n" + (
"Consider using the `static_argnums` parameter for `jax.remat` or "
"`jax.checkpoint`. See the `jax.checkpoint` docstring and its example "
"involving `static_argnums`:\n"
"https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html"
"\n")
e.args = msg,
raise
return pe.convert_constvars_jaxpr(jaxpr), consts, out_tree()


Expand Down
16 changes: 16 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import re
import subprocess
import sys
import traceback
import types
from typing import NamedTuple
import unittest
Expand Down Expand Up @@ -6423,6 +6424,21 @@ def f(x):
y_, = vjp(jnp.ones_like(y))
self.assertAllClose(y, y_, atol=0, rtol=0)

def test_concreteness_error_includes_user_code(self):
@jax.remat
def f(x):
if x > 0:
return x
else:
return jnp.sin(x)

try:
f(3.)
except TracerBoolConversionError:
self.assertIn('x > 0', traceback.format_exc())
else:
assert False


@jtu.with_config(jax_pprint_use_color=False)
class JaxprTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 4231128

Please sign in to comment.