fix jax.custom_gradient to allow closing over non-autodiff tracers
mattjj committed Oct 29, 2024
1 parent 66376a3 commit 86a47a7
Showing 2 changed files with 58 additions and 25 deletions.
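In short, the commit makes the following pattern work: a jax.custom_gradient-decorated function whose primal output and VJP rule both close over a value that is a jit tracer but not an autodiff tracer. A minimal sketch of the newly supported usage, adapted from the test added below (the names f and g are illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
  y = jnp.sin(y)  # y is a jit tracer here, but not an autodiff tracer

  @jax.custom_gradient
  def g(x):
    # both the primal output and the VJP rule close over the traced y
    return y * jnp.sin(x), lambda ct: (y * jnp.cos(x) * ct,)

  return g(x)

print(jax.grad(f)(1., 2.))  # approximately sin(2.) * cos(1.)
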
67 changes: 43 additions & 24 deletions jax/_src/custom_derivatives.py
@@ -41,7 +41,7 @@
from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
tree_flatten, tree_unflatten, tree_map, treedef_tuple,
register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr,
treedef_children)
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
@@ -1029,32 +1029,51 @@ def custom_gradient(fun):
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
"""
@custom_vjp
# TODO(mattjj): better debug info
def wrapped_fun(*args, **kwargs):
ans, _ = fun(*args, **kwargs)
return ans

def fwd(*args, **kwargs):
ans, rule = fun(*args, **kwargs)
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat]
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)

def bwd(res, cts):
jaxpr, in_tree, out_tree, consts = res
cts_flat, out_tree_ = tree_flatten((cts,))
if out_tree != out_tree_: raise TypeError(f'{out_tree}\n!=\n{out_tree_}')
cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
cts_out = tree_unflatten(in_tree, cts_out)
if treedef_is_leaf(in_tree):
cts_out = (cts_out,)
return cts_out

wrapped_fun.defvjp(fwd, bwd)
args_flat, in_tree = tree_flatten((args, kwargs))
in_avals = [core.get_aval(x) for x in args_flat]
primal_jaxpr, fwd_jaxpr, bwd_jaxpr, consts, out_tree = \
_primal_fwd_bwd(in_tree, in_avals)

@custom_vjp
def primal(consts, args):
return core.eval_jaxpr(primal_jaxpr, (), *consts, *args)
def fwd(consts, args):
ans_res = core.eval_jaxpr(fwd_jaxpr, (), *consts, *args)
return split_list(ans_res, [out_tree.num_leaves])
def bwd(res, cts):
return None, core.eval_jaxpr(bwd_jaxpr, res, *cts)
primal.defvjp(fwd, bwd)

out_flat = primal(consts, args_flat)
return tree_unflatten(out_tree, out_flat)

def _primal_fwd_bwd(in_tree, in_avals):
out_tree, rule_jaxpr = None, None
@lu.wrap_init
def run(*args_flat):
nonlocal rule_jaxpr, out_tree
args, kwargs = tree_unflatten(in_tree, args_flat)
ans, rule = fun(*args, **kwargs)
ans_flat, out_tree = tree_flatten((ans,))
ans_bar_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat]
rule_, in_tree_ = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
rule_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule_, ans_bar_avals)
out_tree, = treedef_children(out_tree)
return *ans_flat, *consts
fwd_jaxpr, _, fwd_consts, () = pe.trace_to_jaxpr_dynamic(run, in_avals)
fwd_jaxpr = pe.convert_constvars_jaxpr(fwd_jaxpr)
assert out_tree is not None and rule_jaxpr is not None
num_ans = out_tree.num_leaves
num_res = len(fwd_jaxpr.outvars) - num_ans
primal_jaxpr, _ = pe.dce_jaxpr(fwd_jaxpr,
[True] * num_ans + [False] * num_res, True)
return primal_jaxpr, fwd_jaxpr, rule_jaxpr, fwd_consts, out_tree

return wrapped_fun


@register_pytree_node_class
class Residuals:
def __init__(self, jaxpr, in_tree, out_tree, consts):
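The structural change above: wrapped_fun now eagerly traces fun to a jaxpr (via the new _primal_fwd_bwd helper), so values that fun closes over, such as tracers from an enclosing jit, end up among the traced jaxpr's constants (consts); those constants are passed to the inner custom_vjp function primal as an explicit argument, and bwd returns None (no cotangent) for them. Below is a minimal sketch of that consts-as-an-explicit-argument pattern using the public jax.custom_vjp API. scaled_sin and its rules are illustrative, not library code, and the sketch assumes, as the committed bwd above does, that None is accepted as a zero cotangent for a whole argument.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def scaled_sin(consts, x):          # closed-over values threaded through explicitly
  (y,) = consts
  return y * jnp.sin(x)

def scaled_sin_fwd(consts, x):
  (y,) = consts
  return y * jnp.sin(x), (y, x)     # (primal output, residuals)

def scaled_sin_bwd(res, ct):
  y, x = res
  return None, y * jnp.cos(x) * ct  # None: zero cotangent for consts

scaled_sin.defvjp(scaled_sin_fwd, scaled_sin_bwd)

@jax.jit
def f(x, y):
  y = jnp.sin(y)                    # a jit tracer, threaded in as a "const"
  return scaled_sin((y,), x)

print(jax.grad(f)(1., 2.))          # approximately sin(2.) * cos(1.)
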
16 changes: 15 additions & 1 deletion tests/api_test.py
@@ -8976,13 +8976,27 @@ def f(x):
vjp = lambda g: (jnp.cos(x) * jnp.arange(3., 6.),)
return jnp.sum(jnp.sin(x)), vjp

self.assertAllClose(f(jnp.arange(3)), jnp.sum(jnp.sin(jnp.arange(3.))),
self.assertAllClose(f(jnp.arange(3.)), jnp.sum(jnp.sin(jnp.arange(3.))),
check_dtypes=False)
self.assertAllClose(
api.grad(f)(jnp.arange(3.)),
api.grad(lambda x: jnp.sum(jnp.sin(x)))(jnp.arange(3.)) * jnp.arange(3., 6.),
check_dtypes=False)

def test_custom_gradient_jit_closure(self):
@jax.jit
def f(x, y):
y = jnp.sin(y)

@jax.custom_gradient
def g(x):
return y * jnp.sin(x), lambda g: (y * jnp.cos(x) * g,)

return g(x)

g = jax.grad(f)(1., 2.)
self.assertAllClose(g, jnp.sin(2.) * jnp.cos(1.), check_dtypes=False)

def test_custom_gradient_can_return_singleton_value_in_vjp(self):
@jax.custom_gradient
def f(x):
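The new test_custom_gradient_jit_closure exercises exactly the case the commit title describes: inside the jitted f, y = jnp.sin(y) makes y a jit tracer that is not an autodiff tracer, and the custom_gradient-decorated g closes over it in both its primal output and its VJP rule, a pattern that previously failed.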
