Skip to content

Commit

Permalink
improve scan error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Mar 23, 2023
1 parent d857187 commit ba2ff51
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 50 deletions.
1 change: 0 additions & 1 deletion jax/_src/lax/control_flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
tree_unflatten(tree2, avals2))
raise TypeError(f"{what} must have identical types, got\n{diff}.")


def _check_tree(func_name, expected_name, actual_tree, expected_tree, has_aux=False):
if has_aux:
actual_tree_children = actual_tree.children()
Expand Down
101 changes: 76 additions & 25 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.
"""Module for the loop primitives."""
from functools import partial
import inspect
import itertools
import operator

from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar

import jax
Expand All @@ -30,7 +30,8 @@
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map)
tree_map, tree_flatten_with_path, keystr)
from jax._src.tree_util import equality_errors
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
Expand All @@ -45,26 +46,13 @@
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.util import (
partition_list,
safe_map,
safe_zip,
split_list,
unzip2,
weakref_lru_cache,
)
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
unzip2, weakref_lru_cache)
import numpy as np

from jax._src.lax.control_flow.common import (
_abstractify,
_avals_short,
_check_tree_and_avals,
_initial_style_jaxpr,
_make_closed_jaxpr,
_prune_zeros,
_typecheck_param,
allowed_effects,
)
_abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
_make_closed_jaxpr, _prune_zeros, _typecheck_param, allowed_effects)

_map = safe_map
zip = safe_zip
Expand Down Expand Up @@ -260,14 +248,11 @@ def _create_jaxpr(init):
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
if changed:
new_init = tree_unflatten(init_tree, new_init_flat)
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(new_init)
init = tree_unflatten(init_tree, new_init_flat)
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
in_flat, jaxpr, consts, out_tree, out_tree_children = rest

_check_tree_and_avals("scan carry output and input",
# Extract the subtree and avals for the first element of the return tuple
out_tree_children[0], carry_avals_out,
init_tree, carry_avals)
_check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out)
disallowed_effects = allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
Expand All @@ -280,6 +265,71 @@ def _create_jaxpr(init):
unroll=unroll)
return tree_unflatten(out_tree, out)

def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals):
  """Raise a descriptive ``TypeError`` if the scan carry changes structure or type.

  Args:
    body_fun: the scanned function. It is only inspected with
      ``inspect.signature`` to recover the name of its first parameter, which
      is used to label carry components in error messages.
    in_carry: the initial carry pytree.
    out_carry_tree: the treedef of the carry produced by ``body_fun``.
    out_avals: the flat abstract values of the carry produced by ``body_fun``.

  Raises:
    TypeError: if the input and output carry treedefs differ, or if any pair
      of corresponding leaves fails ``core.typematch``.
  """
  try:
    sig = inspect.signature(body_fun)
  except (ValueError, TypeError):
    sig = None
  # Take the first parameter name without indexing, so that a callable with an
  # empty signature falls back to the generic wording instead of raising
  # IndexError while we are trying to build an error message.
  carry_name = next(iter(sig.parameters), None) if sig is not None else None

  def component(path):
    # Human-readable label for the carry (or the part of it at `path`).
    if carry_name:
      return (f'the input carry component {carry_name}{keystr(path)}'
              if path else f'the input carry {carry_name}')
    return (f'the input carry at path {keystr(path)}'
            if path else 'the input carry')

  paths_and_leaves, in_carry_tree = tree_flatten_with_path(in_carry)
  paths, in_carry_flat = unzip2(paths_and_leaves)
  in_avals = _map(_abstractify, in_carry_flat)

  if in_carry_tree != out_carry_tree:
    # Best effort: rebuild the output carry so that the mismatch can be
    # described node by node. If that fails, fall back to printing treedefs.
    try:
      out_carry = tree_unflatten(out_carry_tree, out_avals)
    except Exception:
      out_carry = None

    if out_carry is None:
      # Must be a string (not a list) because it is interpolated into the
      # message below.
      differences = (f'the input tree structure is:\n{in_carry_tree}\n'
                     f'the output tree structure is:\n{out_carry_tree}\n')
    else:
      differences = '\n'.join(
          f' * {component(path)} is a {thing1} but the corresponding component '
          f'of the carry output is a {thing2}, so {explanation}\n'
          for path, thing1, thing2, explanation
          in equality_errors(in_carry, out_carry))
    raise TypeError(
        "Scanned function carry input and carry output must have the same "
        "pytree structure, but they differ:\n"
        f"{differences}\n"
        "Revise the scanned function so that its output is a pair where the "
        "first element has the same pytree structure as the first argument."
    )

  # Structures agree, so leaves line up one-to-one; collect the mismatches once
  # rather than evaluating typematch twice per leaf.
  mismatches = [(path, in_aval, out_aval)
                for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
                if not core.typematch(in_aval, out_aval)]
  if mismatches:
    differences = '\n'.join(
        f' * {component(path)} has type {in_aval.str_short()}'
        ' but the corresponding output carry component has type '
        f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n'
        for path, in_aval, out_aval in mismatches)
    raise TypeError(
        "Scanned function carry input and carry output must have equal types "
        "(e.g. shapes and dtypes of arrays), "
        "but they differ:\n"
        f"{differences}\n"
        "Revise the scanned function so that all output types (e.g. shapes "
        "and dtypes) match the corresponding input types."
    )

def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
assert not core.typematch(a1, a2)
if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray):
dtype_mismatch = a1.dtype != a2.dtype
shape_mismatch = a1.shape != a2.shape
return (', so ' * (dtype_mismatch or shape_mismatch) +
'the dtypes do not match' * dtype_mismatch +
' and also ' * (dtype_mismatch and shape_mismatch) +
'the shapes do not match' * shape_mismatch)
return ''


def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear,
f_impl, x_avals, y_avals):
consts, init, xs = split_list(args, [num_consts, num_carry])
Expand Down Expand Up @@ -1111,6 +1161,7 @@ def _create_jaxpr(init_val):
body_nconsts=len(body_consts), body_jaxpr=body_jaxpr)
return tree_unflatten(body_tree, outs)


def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts
) -> effects.Effects:
joined_effects = set()
Expand Down
64 changes: 63 additions & 1 deletion jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
from dataclasses import dataclass
Expand Down Expand Up @@ -427,6 +428,67 @@ def prefix_errors(prefix_tree: Any, full_tree: Any,
) -> List[Callable[[str], ValueError]]:
return list(_prefix_error((), prefix_tree, full_tree, is_leaf))

def equality_errors(
    tree1: Any, tree2: Any, is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Iterable[Tuple[KeyPath, str, str, str]]:
  """Lazily yield descriptions of structural differences between two pytrees.

  Each yielded item is a ``(key_path, thing1, thing2, explanation)`` tuple
  produced by ``_equality_errors``, starting from the empty key path at the
  roots of ``tree1`` and ``tree2``. ``is_leaf`` is forwarded unchanged.
  """
  root_path = ()
  yield from _equality_errors(root_path, tree1, tree2, is_leaf)

# TODO(mattjj): maybe share some logic with _prefix_error?
def _equality_errors(path, t1, t2, is_leaf):
# If both are leaves, this isn't a structure equality error.
if (treedef_is_strict_leaf(tree_structure(t1, is_leaf=is_leaf)) and
treedef_is_strict_leaf(tree_structure(t2, is_leaf=is_leaf))): return

# The trees may disagree because they are different types:
if type(t1) != type(t2):
yield path, str(type(t1)), str(type(t2)), 'their Python types differ'
return # no more errors to find

# Or they may disagree because their roots have different numbers or keys of
# children (with special-case handling of list/tuple):
if isinstance(t1, (list, tuple)):
assert type(t1) == type(t2)
if len(t1) != len(t2):
yield (path,
f'{type(t1).__name__} of length {len(t1)}',
f'{type(t2).__name__} of length {len(t2)}',
'the lengths do not match')
return # no more errors to find
t1_children, t1_meta = flatten_one_level(t1)
t2_children, t2_meta = flatten_one_level(t2)
t1_keys, t2_keys = _child_keys(t1), _child_keys(t2)
try:
diff = ' '.join(repr(k.key) for k in
set(t1_keys).symmetric_difference(set(t2_keys)))
except:
diff = ''
if len(t1_children) != len(t2_children):
yield (path,
f'{type(t1)} with {len(t1_children)} child'
f'{"ren" if len(t1_children) > 1 else ""}',
f'{type(t2)} with {len(t2_children)} child'
f'{"ren" if len(t2_children) > 1 else ""}',
'the numbers of children do not match' +
(diff and f', with the symmetric difference of key sets: {{{diff}}}')
)
return # no more errors to find

# Or they may disagree if their roots have different pytree metadata:
if t1_meta != t2_meta:
yield (path,
f'{type(t1)} with pytree metadata {t1_meta}',
f'{type(t2)} with pytree metadata {t2_meta}',
'the pytree node metadata does not match')
return # no more errors to find

# If the root types and numbers of children agree, there must be a mismatch in
# a subtree, so recurse:
assert t1_keys == t2_keys, \
f"equal pytree nodes gave different tree keys: {t1_keys} and {t2_keys}"
for k, c1, c2 in zip(t1_keys, t1_children, t2_children):
yield from _equality_errors((*path, k), c1, c2, is_leaf)


# TODO(ivyzheng): Remove old APIs when all users migrated.

class _DeprecatedKeyPathEntry(NamedTuple):
Expand Down Expand Up @@ -800,7 +862,7 @@ def _prefix_error(key_path: KeyPath, prefix_tree: Any, full_tree: Any,
("equal pytree nodes gave differing prefix_tree_keys: "
f"{prefix_tree_keys} and {full_tree_keys}")
for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
yield from _prefix_error(tuple((*key_path, k)), t1, t2)
yield from _prefix_error((*key_path, k), t1, t2)


# TODO(jakevdp) remove these deprecated wrappers & their imports in jax/__init__.py
Expand Down
107 changes: 84 additions & 23 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,31 +1762,92 @@ def plus_one(p, iter_idx):
'scan got value with no leading axis to scan over.*',
lambda: lax.scan(plus_one, p0, list(range(5))))

def testScanBodyOutputError(self):
  """scan rejects a body function whose output is not a (carry, y) pair."""
  with self.assertRaisesRegex(
      TypeError,
      re.escape("scan body output must be a pair, got ShapedArray(float32[]).")):
    lax.scan(lambda c, x: np.float32(0.), 0, jnp.arange(5.))

def testScanBodyCarryPytreeMismatchErrors(self):
  """scan reports where the carry's pytree structure changes across the body."""
  # Tuple length mismatch at the root; the carry is named after the lambda's
  # first parameter.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("Scanned function carry input and carry output must have "
                "the same pytree structure, but they differ:\n"
                " * the input carry c is a tuple of length 2")):
    lax.scan(lambda c, x: ((0, 0, 0), x), (1, (2, 3)), jnp.arange(5.))

  with self.assertRaisesRegex(
      TypeError,
      re.escape("Scanned function carry input and carry output must have the "
                "same pytree structure, but they differ:\n"
                " * the input carry x is a tuple of length 2")):
    lax.scan(lambda x, _: ((x[0].astype('float32'),), None),
             (jnp.array(0, 'int32'),) * 2, None, length=1)

  # Container type mismatch (tuple in, list out).
  with self.assertRaisesRegex(
      TypeError,
      re.escape("Scanned function carry input and carry output must have the "
                "same pytree structure, but they differ:\n"
                " * the input carry x is a <class 'tuple'> but the corres")):
    jax.lax.scan(lambda x, _: ([x[0].astype('float32'),] * 2, None),
                 (jnp.array(0, 'int32'),) * 2, None, length=1)

  # Different number of dict children at the root.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("Scanned function carry input and carry output must have the "
                "same pytree structure, but they differ:\n"
                " * the input carry x is a <class 'dict'> with 1 child but")):
    jax.lax.scan(lambda x, _: ({'a': x['a'], 'b': x['a']}, None),
                 {'a': jnp.array(0, 'int32')}, None, length=1)

  # Different number of dict children inside a nested component.
  with self.assertRaisesRegex(
      TypeError,
      re.escape("Scanned function carry input and carry output must have the "
                "same pytree structure, but they differ:\n"
                " * the input carry component x[0] is a <class 'dict'> with "
                "1 child but the corresponding component of the carry "
                "output is a <class 'dict'> with 2 children")):
    jax.lax.scan(lambda x, _: (({'a': x[0]['a'], 'b': x[0]['a']},) * 2, None),
                 ({'a': jnp.array(0, 'int32')},) * 2, None, length=1)

def testScanBodyCarryTypeMismatchErrors(self):
  """scan reports each carry leaf whose dtype and/or shape changes in the body."""
  preamble = ("Scanned function carry input and carry output must have equal "
              "types (e.g. shapes and dtypes of arrays), but they differ:\n")
  int_scalar = jnp.array(0, 'int32')

  # Scalar carry whose dtype changes.
  scalar_msg = (preamble +
                " * the input carry x has type int32[] but the corresponding "
                "output carry component has type float32[], so the dtypes do "
                "not match")
  with self.assertRaisesRegex(TypeError, re.escape(scalar_msg)):
    jax.lax.scan(lambda x, _: (x.astype('float32'), None),
                 int_scalar, None, length=1)

  # Only the second component of a tuple carry changes dtype.
  one_leaf_msg = (preamble +
                  " * the input carry component x[1] has type int32[] but the "
                  "corresponding output carry component has type float32[], "
                  "so the dtypes do not match")
  with self.assertRaisesRegex(TypeError, re.escape(one_leaf_msg)):
    jax.lax.scan(lambda x, _: ((x[0], x[1].astype('float32')), None),
                 (int_scalar,) * 2, None, length=1)

  # Two of three components mismatch; the second differs in dtype and shape.
  two_leaf_msg = (preamble +
                  " * the input carry component x[0] has type int32[] but the "
                  "corresponding output carry component has type float32[], "
                  "so the dtypes do not match\n\n"
                  " * the input carry component x[1] has type int32[] but the "
                  "corresponding output carry component has type float32[1,1], "
                  "so the dtypes do not match and also the shapes do not match")
  with self.assertRaisesRegex(TypeError, re.escape(two_leaf_msg)):
    jax.lax.scan(lambda x, _: ((x[0].astype('float32'),
                                x[1].astype('float32').reshape(1, 1),
                                x[2]), None),
                 (int_scalar,) * 3, None, length=1)

@parameterized.named_parameters(
{"testcase_name": f"_{scan_name}",
Expand Down

0 comments on commit ba2ff51

Please sign in to comment.