Skip to content

Commit

Permalink
Fix exception causes in api.py (jax-ml#2336)
Browse files Browse the repository at this point in the history
  • Loading branch information
cool-RR authored Mar 4, 2020
1 parent 1e61ba4 commit 52a4131
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,8 @@ def _check_scalar(x):
msg = "Gradient only defined for scalar-output functions. Output {}.".format
try:
aval = core.get_aval(x)
except TypeError:
raise TypeError(msg("was {}".format(x)))
except TypeError as e:
raise TypeError(msg("was {}".format(x))) from e
else:
if isinstance(aval, ShapedArray):
if aval.shape != ():
Expand Down Expand Up @@ -655,7 +655,7 @@ def _check_axis_sizes(tree, vals, dims):
mapped_axis_sizes = {x.shape[d] for x, d in zip(vals, dims) if d is not None}
try:
sizes, = mapped_axis_sizes
except ValueError:
except ValueError as e:
msg = "vmap got inconsistent sizes for array axes to be mapped:\n{}"
# we switch the error message based on whether args is a tuple of arrays,
# in which case we can produce an error message based on argument indices,
Expand All @@ -675,11 +675,11 @@ def _check_axis_sizes(tree, vals, dims):
"axes" if len(idxs) > 1 else "an axis",
size)
for size, idxs in sizes.items()]
raise ValueError(msg.format("\n".join(lines1 + ["so"] + lines2)))
raise ValueError(msg.format("\n".join(lines1 + ["so"] + lines2))) from e
else:
sizes = [x.shape[d] if d is not None else None for x, d in zip(vals, dims)]
sizes = tree_unflatten(tree, sizes)
raise ValueError(msg.format("the tree of axis sizes is:\n{}".format(sizes)))
raise ValueError(msg.format("the tree of axis sizes is:\n{}".format(sizes))) from e

@wraps(fun, docstr=docstr)
def batched_fun(*args):
Expand All @@ -706,10 +706,10 @@ def _flatten_axes(treedef, axis_tree):
add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
try:
tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
except ValueError:
except ValueError as e:
msg = ("axes specification must be a tree prefix of the corresponding "
"value, got specification {} for value {}.")
raise ValueError(msg.format(axis_tree, treedef))
raise ValueError(msg.format(axis_tree, treedef)) from e
axes = [None if a is proxy else a for a in axes]
assert len(axes) == treedef.num_leaves
return axes
Expand Down Expand Up @@ -1223,10 +1223,10 @@ def fun(*tangents):
for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
try:
core.lattice_join(primal_aval, tangent_aval)
except TypeError:
except TypeError as e:
msg = ("linearized function called on tangent values inconsistent with "
"the original primal values.")
raise ValueError(msg)
raise ValueError(msg) from e
dummy = (core.unit,) * len(tangents)
out = eval_jaxpr(jaxpr, consts, *(dummy + tangents))
tangents_out = out[len(out)//2:]
Expand Down Expand Up @@ -1979,8 +1979,8 @@ def __init__(self, shape, dtype):
def __len__(self):
try:
return self.shape[0]
except IndexError:
raise TypeError("len() of unsized object") # same as numpy error
except IndexError as e:
raise TypeError("len() of unsized object") from e # same as numpy error

def __repr__(self):
return "{}(shape={}, dtype={})".format(
Expand Down

0 comments on commit 52a4131

Please sign in to comment.