Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
hawkinsp authored and jax authors committed Apr 4, 2023
1 parent 3c1f3ab commit c1f65fc
Showing 24 changed files with 486 additions and 480 deletions.
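The change repeated across these files (only two of the 24 are shown below) is to replace imports from the public jax.* namespace inside internal jax._src modules with imports of the corresponding jax._src submodules, so that internal Bazel targets no longer have to depend on the top-level jax package. A minimal sketch of the before/after import style; this is not code from the commit, but the replacement modules (jax._src.api, jax._src.config) and the calls are taken from the diff below:

# Old style inside a jax._src module: importing the public package creates a
# dependency cycle jax._src.<module> -> jax -> jax._src.<other modules>.
#
#   import jax
#   jaxpr, shape = jax.make_jaxpr(lambda x: x * 2, return_shape=True)(1.0)
#   log_residuals = jax.config.jax_log_checkpoint_residuals

# New style: reach the same objects through jax._src directly.
from jax._src import api
from jax._src.config import config

jaxpr, shape = api.make_jaxpr(lambda x: x * 2, return_shape=True)(1.0)
log_residuals = config.jax_log_checkpoint_residuals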
6 changes: 3 additions & 3 deletions jax/__init__.py
@@ -76,7 +76,7 @@

from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
-from jax._src.api import checkpoint as checkpoint
+from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
from jax._src.api import clear_backends as clear_backends
from jax._src.custom_derivatives import closure_convert as closure_convert
@@ -116,8 +116,8 @@
from jax._src.api import pmap as pmap
from jax._src.xla_bridge import process_count as process_count
from jax._src.xla_bridge import process_index as process_index
-from jax._src.api import pure_callback as pure_callback
-from jax._src.api import remat as remat
+from jax._src.callback import pure_callback_api as pure_callback
+from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.core import ShapedArray as _deprecated_ShapedArray
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
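With the rebindings above, the public names jax.checkpoint and jax.remat both resolve to checkpoint_wrapper from jax._src.ad_checkpoint (added at the end of the next file), and jax.pure_callback resolves to pure_callback_api; user-facing behavior is unchanged. A quick sanity check of the aliasing, assuming a JAX build that includes this commit:

import jax
import jax.numpy as jnp

# Both public names are now the same wrapper object.
assert jax.remat is jax.checkpoint

@jax.checkpoint
def f(x):
  return jnp.sin(jnp.sin(x))

print(jax.grad(f)(1.0))  # the inner sin is rematerialized on the backward pass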
90 changes: 78 additions & 12 deletions jax/_src/ad_checkpoint.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import functools
from functools import partial
import logging
from typing import (Any, Callable, FrozenSet, List, Optional, Sequence, Tuple,
@@ -20,9 +21,8 @@

import numpy as np

-import jax
-from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src import ad_util
+from jax._src import api
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
@@ -31,6 +31,7 @@
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
+from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -39,6 +40,7 @@
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
+from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)

@@ -389,7 +391,7 @@ def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)

-  out = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1],
+  out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
return_shape=True)(*in_leaves)
assert isinstance(out, tuple)
jaxpr_, out_shape = out
@@ -522,7 +524,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params):
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
-    logger.log(logging.WARNING if jax.config.jax_log_checkpoint_residuals
+    logger.log(logging.WARNING if config.jax_log_checkpoint_residuals
else logging.DEBUG,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
@@ -652,7 +654,7 @@ def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
assert not jaxpr.constvars

if differentiated and prevent_cse:
-    if jax.config.jax_remat_opt_barrier:
+    if config.jax_remat_opt_barrier:
translation_rule = _remat_translation_using_opt_barrier
elif is_gpu_platform:
translation_rule = _remat_translation_using_while
@@ -661,7 +663,7 @@
else:
translation_rule = lambda *args, jaxpr: core.eval_jaxpr(jaxpr, (), *args)

-  return jax.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)
+  return api.named_call(translation_rule, name="remat")(*args, jaxpr=jaxpr)

def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
args = _optimization_barrier(args)
@@ -670,9 +672,9 @@ def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
# TODO(mattjj): add core utility for 'create dummy value for this type'?
def _dummy_like(aval: core.AbstractValue) -> Any:
if aval is core.abstract_token:
-    return jax.lax.create_token()
+    return lax_internal.create_token()
elif isinstance(aval, (core.ShapedArray, core.DShapedArray)):
-    return jax.lax.broadcast(lax_internal.empty(aval.dtype), aval.shape)  # type: ignore
+    return lax_internal.broadcast(lax_internal.empty(aval.dtype), aval.shape)  # type: ignore
else:
raise ValueError(aval)

Expand All @@ -682,19 +684,21 @@ def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
# result = eval_jaxpr(*args)
# }
# The loop carry is a tuple: (counter, result, args)
+  from jax._src.lax import control_flow as lax_control_flow

avals_out = tuple(v.aval for v in jaxpr.outvars)
carry_init = (np.int32(0), tuple(map(_dummy_like, avals_out)), args)
def cond(carry):
counter, _, _ = carry
-    unif = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=())
+    unif = lax_internal.rng_uniform(np.int32(1), np.int32(2), shape=())
return counter < unif

def body(carry):
counter, _, args = carry
results = core.eval_jaxpr(jaxpr, (), *args)
return (counter + 1, tuple(results), args)

-  carry_res = jax.lax.while_loop(cond, body, carry_init)
+  carry_res = lax_control_flow.while_loop(cond, body, carry_init)
return carry_res[1]

def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
@@ -703,15 +707,17 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
# return eval_jaxpr(*args)
# else:
# return 0
+  from jax._src.lax import control_flow as lax_control_flow

avals_out = tuple(v.aval for v in jaxpr.outvars)

def remat_comp(*args):
return tuple(core.eval_jaxpr(jaxpr, (), *args))
def dummy_comp(*args):
return tuple(map(_dummy_like, avals_out))

-  unif = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=())
-  return jax.lax.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
+  unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=())
+  return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)

mlir.register_lowering(
remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
@@ -760,3 +766,63 @@ def name_batcher(args, dims, *, name):
(x,), (d,) = args, dims
return name_p.bind(x, name=name), d
batching.primitive_batchers[name_p] = name_batcher


@functools.wraps(checkpoint)
def checkpoint_wrapper(
fun: Callable,
*,
concrete: bool = False,
prevent_cse: bool = True,
static_argnums: Union[int, Tuple[int, ...]] = (),
policy: Optional[Callable[..., bool]] = None,
) -> Callable:
if concrete:
msg = ("The 'concrete' option to jax.checkpoint / jax.remat is deprecated; "
"in its place, you can use its `static_argnums` option, and if "
"necessary the `jax.ensure_compile_time_eval()` context manager.\n"
"\n"
"For example, if using `concrete=True` for an `is_training` flag:\n"
"\n"
" from functools import partial\n"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, is_training):\n"
" if is_training:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"replace it with a use of `static_argnums`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, is_training):\n"
" ...\n"
"\n"
"If jax.numpy operations need to be performed on static arguments, "
"we can use the `jax.ensure_compile_time_eval()` context manager. "
"For example, we can replace this use of `concrete=True`\n:"
"\n"
" @partial(jax.checkpoint, concrete=True)\n"
" def foo(x, y):\n"
" if y > 0:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"with this combination of `static_argnums` and "
"`jax.ensure_compile_time_eval()`:\n"
"\n"
" @partial(jax.checkpoint, static_argnums=(1,))\n"
" def foo(x, y):\n"
" with jax.ensure_compile_time_eval():\n"
" y_pos = y > 0\n"
" if y_pos:\n"
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n"
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
raise NotImplementedError(msg)
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
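
The checkpoint_wrapper shown above forwards its supported arguments straight to checkpoint; the deprecated concrete=True option is rejected up front with the migration message. A short illustration of both paths, assuming a JAX build that includes this commit (the function apply is only for the example):

import functools
import jax
import jax.numpy as jnp

# The deprecated option fails fast, at decoration time.
try:
  jax.checkpoint(lambda x: x, concrete=True)
except NotImplementedError:
  print("concrete=True is rejected; use static_argnums instead")

# The supported replacement: declare the branch-controlling argument static.
@functools.partial(jax.checkpoint, static_argnums=(1,))
def apply(x, is_training):
  return jnp.sin(x) if is_training else x

print(jax.grad(apply)(1.0, True))  # cos(1.0)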