Skip to content

Commit

Permalink
allow custom_vjp bwd to return Nones for zeros
Browse files Browse the repository at this point in the history
This change sets up some internal users so that we can then land jax-ml#4008.
  • Loading branch information
mattjj committed Oct 15, 2020
1 parent fb6b3bf commit 3a75145
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions jax/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .util import safe_zip, safe_map, split_list
from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
from .abstract_arrays import raise_to_shaped
from .ad_util import Zero, stop_gradient_p
from .ad_util import Zero, zeros_like_aval, stop_gradient_p
from .interpreters import partial_eval as pe
from .interpreters import ad
from .interpreters import batching
Expand Down Expand Up @@ -463,9 +463,10 @@ def __call__(self, *args, **kwargs):
f_, dyn_args = lu.wrap_init(self.fun), args
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
flat_bwd = _flatten_bwd(bwd, in_tree, out_trees)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
if _initial_style_staging():
out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
Expand Down Expand Up @@ -493,21 +494,33 @@ def _flatten_fwd(in_tree, *args):
yield res + out, (out_tree, res_tree)

@lu.transformation
def _flatten_bwd(in_tree, out_trees, *args):
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
out_tree, res_tree = out_trees()
res, cts_out = split_list(args, [res_tree.num_leaves])
py_res = tree_unflatten(res_tree, res)
py_cts_out = tree_unflatten(out_tree, cts_out)
py_cts_in = yield (py_res, py_cts_out), {}
cts_in, in_tree2 = tree_flatten(py_cts_in)
if in_tree != in_tree2:
# For each None in py_cts_in, indicating an argument for which the rule
# produces no cotangent, we replace it with a pytree with the structure of the
# corresponding subtree of in_tree and with leaves of a non-pytree sentinel
# object, to be replaced with Nones in the final returned result.
zero = object() # non-pytree sentinel to replace Nones in py_cts_in
py_cts_in_ = tuple(zero if ct is None else ct for ct in py_cts_in)
dummy = tree_unflatten(in_tree, [object()] * in_tree.num_leaves)
cts_in_flat = []
append_cts = lambda x, d: cts_in_flat.extend([x] * len(tree_flatten(d)[0]))
try:
tree_multimap(append_cts, py_cts_in_, dummy)
except ValueError:
_, in_tree2 = tree_flatten(py_cts_in)
msg = ("Custom VJP rule must produce an output with the same container "
"(pytree) structure as the args tuple of the primal function, "
"and in particular must produce a tuple of length equal to the "
"number of arguments to the primal function, but got VJP output "
"structure {} for primal input structure {}.")
raise TypeError(msg.format(in_tree2, in_tree)) from None
yield cts_in
yield [zeros_like_aval(aval.at_least_vspace()) if ct is zero else ct
for aval, ct in zip(in_avals, cts_in_flat)]


class CustomVJPCallPrimitive(core.CallPrimitive):
Expand Down

0 comments on commit 3a75145

Please sign in to comment.