Skip to content

Commit

Permalink
Merge pull request jax-ml#25593 from mattjj:ref-errors-4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707733777
  • Loading branch information
Google-ML-Automation committed Dec 19, 2024
2 parents 9041b02 + e528562 commit f65eced
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 32 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ pytype_strict_library(
":config",
":core",
":dtypes",
":state_types",
":traceback_util",
":tree_util",
":util",
Expand Down
30 changes: 30 additions & 0 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import numpy as np

from jax._src import core
from jax._src import config
from jax._src import dtypes
from jax._src.state.types import AbstractRef
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.tree_util import (
Expand Down Expand Up @@ -737,3 +739,31 @@ def __eq__(self, other):
def register_class_with_attrs(t: type) -> None:
_class_with_attrs.add(t)
_class_with_attrs: set[type] = set()

# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
if (isinstance(a, AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
f"the mutable array reference of type {a.str_short()} appeared at both "
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None

def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
for i, x in enumerate(args):
if id(core.get_referent(x)) in refs:
a = shaped_abstractify(x)
raise ValueError(
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")
13 changes: 11 additions & 2 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.api_util import shaped_abstractify
from jax._src.api_util import (
shaped_abstractify, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
Expand Down Expand Up @@ -271,13 +273,20 @@ def scan(f, init, xs, length=None):
xs_avals = [core.get_aval(x) for x in xs_flat]
x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]

if config.mutable_array_checks.value:
in_flat, in_tree = tree_flatten((init, xs))
dbg = pe.debug_info(f, in_tree, None, False, 'scan')
in_avals = tuple(_map(core.get_aval, in_flat))
_check_no_aliased_ref_args(dbg, in_avals, in_flat)

def _create_jaxpr(init):
init_flat, init_tree = tree_flatten(init)
in_flat, in_tree = tree_flatten((init, xs))

carry_avals = tuple(_map(core.get_aval, init_flat))
jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
f, in_tree, (*carry_avals, *x_avals), "scan")
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), in_flat)
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."
Expand Down
32 changes: 5 additions & 27 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info,
hoist_obj_attrs)
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import xla
Expand Down Expand Up @@ -627,7 +628,8 @@ def _infer_params_impl(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
IgnoreKey(ji.inline))
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)

out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
Expand Down Expand Up @@ -764,33 +766,9 @@ def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]
" static_argnums or static_argnames parameters of jax.jit."
) from None
if config.mutable_array_checks.value:
# TODO(mattjj): make this faster
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, explicit_args)):
if (isinstance(a, AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} "
f"the mutable array reference of type {a.str_short()} appeared at both "
f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}."
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None
_check_no_aliased_ref_args(dbg, avals, explicit_args)
return tuple(avals)

def _check_no_aliased_closed_over_refs(dbg, consts, args) -> None:
if not config.mutable_array_checks.value: return
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
for i, x in enumerate(args):
if id(core.get_referent(x)) in refs:
a = shaped_abstractify(x)
raise ValueError(
f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable "
f"array reference of type {a.str_short()} was both closed over and "
f"passed as the argument "
f"{dbg.arg_names[i]}" if dbg else "at flat index {i}")

def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
explicit_args: Sequence[Any]
Expand Down
15 changes: 12 additions & 3 deletions tests/mutable_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def test_scan_scanned_mut_array(self, jit):
def body_fun(_, index_x):
(index, x) = index_x
x[...] += index
# breakpoint()
return ((), x[...])

x_mut = core.mutable_array(np.arange(5))
Expand Down Expand Up @@ -289,8 +288,18 @@ def test_return_from_scan(self):
ValueError, "traced for scan returned a mutable array reference of type"):
jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3))

# TODO test_argument_aliases_scan
# TODO test_closure_and_argument_aliases_scan
def test_argument_aliases_scan(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"appeared at both c\[0\] and c\[1\]"):
jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1)

def test_closure_and_argument_aliases_scan(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"closed over and passed as the argument y_ref"):
jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref,
None, length=1)

def test_return_from_cond(self):
with self.assertRaisesRegex(
Expand Down

0 comments on commit f65eced

Please sign in to comment.