Skip to content

Commit

Permalink
add experimental jax_log_checkpoint_residuals option
Browse files Browse the repository at this point in the history
The main idea here is to improve tooling for knowing what residuals are being
saved and why. There's a lot more that can be done here (e.g. naming the
arguments, explaining what JVP rule produced these residuals, explaining what
consumed them, etc) but this is a start.

Co-authored-by: Qiao Zhang <[email protected]>
  • Loading branch information
mattjj and zhangqiaorjc committed Mar 22, 2023
1 parent 64e1f5f commit 6b4262d
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 3 deletions.
32 changes: 29 additions & 3 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from functools import partial
import logging
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
Union)
import types
Expand Down Expand Up @@ -47,6 +48,8 @@
map = safe_map
zip = safe_zip

logger = logging.getLogger(__name__)

allowed_effects: effects.EffectTypeSet = effects.remat_allowed_effects

### Policies
Expand Down Expand Up @@ -392,6 +395,12 @@ def f_(*args):
jaxpr_, out_shape = out
jaxpr = jaxpr_.jaxpr
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
arg_info = pe.arg_info_all(dbg)
return _saved_residuals(jaxpr, arg_info)

def _saved_residuals(jaxpr, arg_info) -> List[Tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

Expand All @@ -404,9 +413,6 @@ def f_(*args):
if v in res_vars:
results.append((v.aval, 'from a constant'))

assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
arg_info = pe.arg_info_all(dbg)
for i, v in enumerate(jaxpr.invars):
if v in res_vars:
if arg_info is not None:
Expand Down Expand Up @@ -509,6 +515,26 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
new_params, jaxpr_unknown.effects,
source_info_util.current())

# log info about saved residuals
try:
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
logger.log(logging.WARNING if jax.config.jax_log_checkpoint_residuals
else logging.DEBUG,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
' %s\n' * len(res_invars) +
'and ' * bool(res_invars) * bool(body_res) +
'saving these intermediates:\n' * bool(body_res) +
' %s from %s\n' * len(body_res),
*[v.aval.str_short() for v in res_invars],
*[elt for (a, s) in body_res for elt in [a.str_short(), s]])
except:
pass # just don't log anything on failure

for t in out_jaxpr_tracers: t.recipe = recipe

# zip together known and unknown outputs
Expand Down
7 changes: 7 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,13 @@ def update_thread_local_jit_state(**kw):
'option is set, the log level is WARNING; otherwise the level is '
'DEBUG.'))

log_compiles = config.define_bool_state(
name='jax_log_checkpoint_residuals',
default=False,
help=('Log a message every time jax.checkpoint (aka jax.remat) is '
'partially evaluated (e.g. for autodiff), printing what residuals '
'are saved.'))

parallel_functions_output_gda = config.define_bool_state(
name='jax_parallel_functions_output_gda',
default=False,
Expand Down
50 changes: 50 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5715,6 +5715,56 @@ def f(x):
self.assertIn(' sin ', str(jaxpr))
self.assertIn(' cos ', str(jaxpr))

def test_remat_residual_logging(self):
def f(x):
x = jnp.sin(x)
x = jnp.cos(x.sum())
return x

x = jnp.arange(3.)

f1 = jax.remat(f)
f2 = jax.remat(f, policy=lambda *_, **__: True)
f3 = jax.remat(f, policy=lambda p, *_, **__: str(p) == 'cos')

prev_level = logging.get_verbosity()
try:
logging.set_verbosity('DEBUG')
with self.assertLogs(level=logging.DEBUG) as l:
jax.grad(f1)(x)
finally:
logging.set_verbosity(prev_level)
self.assertTrue(any('remat-decorated function saving inputs with shapes:'
in line for line in l.output))
self.assertFalse(any('intermediates' in line for line in l.output))

prev_level = logging.get_verbosity()
try:
logging.set_verbosity('DEBUG')
with self.assertLogs(level=logging.DEBUG) as l:
jax.grad(f2)(x)
finally:
logging.set_verbosity(prev_level)
self.assertFalse(any('saving inputs' in line for line in l.output))
self.assertTrue(any('remat-decorated function saving these intermediates:'
in line for line in l.output))
self.assertTrue(any(' sin ' in line for line in l.output))
self.assertTrue(any(' cos ' in line for line in l.output))

prev_level = logging.get_verbosity()
try:
logging.set_verbosity('DEBUG')
with self.assertLogs(level=logging.DEBUG) as l:
jax.grad(f3)(x)
finally:
logging.set_verbosity(prev_level)
self.assertTrue(any('remat-decorated function saving inputs with shapes:'
in line for line in l.output))
self.assertTrue(any('and saving these intermediates:'
in line for line in l.output))
self.assertFalse(any(' sin ' in line for line in l.output))
self.assertTrue(any(' cos ' in line for line in l.output))


class JaxprTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 6b4262d

Please sign in to comment.