Skip to content

Commit

Permalink
Add support for custom derivatives in jax.experimental.callback
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Mar 22, 2021
1 parent 9a2a1ad commit 252bd6c
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 14 deletions.
73 changes: 59 additions & 14 deletions jax/experimental/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from jax import core
from jax.core import Trace, Tracer, jaxpr_as_fun
from jax import lax
from jax import custom_derivatives as cd
from jax.interpreters import partial_eval as pe
from jax import linear_util as lu
from jax._src.util import partial, safe_map, wraps, split_list
from jax._src.lax import control_flow as lcf

import inspect
from jax.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_leaves, tree_map
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_map

map = safe_map

Expand Down Expand Up @@ -114,6 +116,14 @@ def _callback_fun(callback, strip_calls, *in_vals, **params):
del main
yield out_vals

def callback_jaxpr(closed_jaxpr, callback, strip_calls):
fun = lu.wrap_init(jaxpr_as_fun(closed_jaxpr))
fun = callback_subtrace(fun)
fun = _callback_fun(fun, callback, strip_calls)
avals_in = closed_jaxpr.in_avals
jaxpr_out, consts = cd._initial_style_jaxpr(fun, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts)

def _check_callable(fun):
if not callable(fun):
raise TypeError(f"Expected a callable value, got {fun}")
Expand Down Expand Up @@ -164,18 +174,20 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
return [CallbackTracer(self, val) for val in vals_out]

def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
# This implementation just drops the custom derivative rule.
# TODO(sharadmv): don't drop the custom derivative rule
del primitive, jvp # Unused.
return fun.call_wrapped(*tracers)
vals_in = [t.val for t in tracers]
fun = callback_subtrace(fun, self.main)
jvp = callback_subtrace(jvp, self.main)
out = primitive.bind(fun, jvp, *vals_in)
return safe_map(self.pure, out)

def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees):
# This implementation just drops the custom derivative rule.
# TODO(sharadmv): don't drop the custom derivative rule
del primitive, fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*tracers)

vals_in = [t.val for t in tracers]
fun = callback_subtrace(fun, self.main)
fwd = callback_subtrace(fwd, self.main)
bwd = callback_subtrace(bwd, self.main)
out = primitive.bind(fun, fwd, bwd, *vals_in, out_trees=out_trees)
return safe_map(self.pure, out)

custom_callback_rules: Dict[Any, Any] = {}

Expand All @@ -190,14 +202,13 @@ def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry,

body_fun = jaxpr_as_fun(jaxpr)

def new_body(carry, x):
flat_args = tree_leaves((carry, x))
out = body_fun(*(const_vals + flat_args))
def new_body(*vals):
out = body_fun(*vals)
out_carry, y = split_list(out, [num_carry])
return out_carry, y
main = trace.main
new_body = callback_transform(new_body, main.callback, strip_calls=main.strip_calls) # type: ignore
in_tree = tree_structure(tuple(carry_avals + xs_avals))
in_tree = tree_structure(carry_avals + xs_avals)
new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
new_body, in_tree, tuple(carry_avals + x_avals))
vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
Expand Down Expand Up @@ -242,3 +253,37 @@ def body(*carry):
return safe_map(trace.pure, out)

custom_callback_rules[lax.while_p] = _while_callback_rule

def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
fun_jaxpr, num_consts, **params):
main = trace.main
vals = [t.val for t in tracers]

new_closed_jaxpr = callback_jaxpr(fun_jaxpr, main.callback, strip_calls=main.strip_calls)
if primitive == cd.custom_jvp_call_jaxpr_p:
thunk_name = 'jvp_jaxpr_thunk'
elif primitive == cd.custom_vjp_call_jaxpr_p:
thunk_name = 'fwd_jaxpr_thunk'
params['bwd'] = callback_subtrace(params['bwd'], main)
else:
raise NotImplementedError(primitive)

thunk = params.pop(thunk_name)
@pe._memoize
def new_thunk():
thunk_jaxpr = core.ClosedJaxpr(*thunk())
closed_jaxpr = callback_jaxpr(thunk_jaxpr, main.callback, main.strip_calls)
return closed_jaxpr.jaxpr, closed_jaxpr.literals

params[thunk_name] = new_thunk
new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
new_num_consts = len(new_consts) + num_consts
out = primitive.bind(*it.chain(new_consts, vals), fun_jaxpr=closed_fun_jaxpr,
num_consts=new_num_consts, **params)
return safe_map(trace.pure, out)

custom_callback_rules[cd.custom_jvp_call_jaxpr_p] = partial(
_custom_derivative_call_jaxpr_callback_rule, cd.custom_jvp_call_jaxpr_p)
custom_callback_rules[cd.custom_vjp_call_jaxpr_p] = partial(
_custom_derivative_call_jaxpr_callback_rule, cd.custom_vjp_call_jaxpr_p)
131 changes: 131 additions & 0 deletions tests/callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jax.numpy as jnp
from jax import lax
from jax import jit
from jax import grad

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -158,6 +159,136 @@ def body(i, x):
}
self.assertAllClose(rewrite(f, rewrites)(x), 11)

def testRewriteThroughCustomVJP(self):

@jax.custom_gradient
def f(x):
return x * 2, lambda g: g + x

x = 2.
self.assertAllClose(f(x), 4.)
self.assertAllClose(grad(f)(x), 3.)

rewrites = {
lax.mul_p: lambda x, y: x / y
}
g = rewrite(f, rewrites)

self.assertAllClose(g(x), 1.)
self.assertAllClose(grad(g)(x), 3.)

rewrites = {
lax.add_p: lambda x, y: x - y
}
g = rewrite(f, rewrites)

self.assertAllClose(g(x), 4.)
self.assertAllClose(grad(g)(x), -1.)

def testRewriteThroughCustomVJPInScan(self):

@jax.custom_gradient
def foo(x):
return x * 2, lambda g: g + x

def f(x):
out, _ = lax.scan(lambda c, _: (foo(c), None), x, None, length=1)
return out

x = 2.
self.assertAllClose(f(x), 4.)
self.assertAllClose(grad(f)(x), 3.)

rewrites = {
lax.mul_p: lambda x, y: x / y
}
g = rewrite(f, rewrites)

self.assertAllClose(g(x), 1.)
self.assertAllClose(grad(g)(x), 3.)

rewrites = {
lax.add_p: lambda x, y: x * y
}
g = rewrite(f, rewrites)

self.assertAllClose(g(x), 4.)
self.assertAllClose(grad(g)(x), 2.)

def testRewriteThroughCustomJVP(self):

@jax.custom_jvp
def f(x):
return x + 2

@f.defjvp
def f_jvp(primals, tangents):
x, = primals
d, = tangents
return f(x), x * d

x = 2.
self.assertAllClose(f(x), 4.)
f_primal, jvp = jax.jvp(f, (x,), (1.,))
self.assertAllClose(f_primal, 4.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(f)(x), 2.)

rewrites = {
lax.add_p: lambda x, y: x - y
}
g = rewrite(f, rewrites)

self.assertAllClose(g(x), 0.)
g_primal, jvp = jax.jvp(g, (x,), (1.,))
self.assertAllClose(g_primal, 0.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(g)(x), 2.)

def testRewriteThroughCustomJVPInScan(self):

@jax.custom_jvp
def foo(x):
return x + 2

@foo.defjvp
def foo_jvp(primals, tangents):
x, = primals
d, = tangents
return f(x), x * d
def f(x):
out, _ = lax.scan(lambda c, _: (foo(c), None), x, None, length=1)
return out

x = 2.
self.assertAllClose(f(x), 4.)
f_primal, jvp = jax.jvp(f, (x,), (1.,))
self.assertAllClose(f_primal, 4.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(f)(x), 2.)

rewrites = {
lax.add_p: lambda x, y: x - y
}
g = rewrite(f, rewrites)

self.assertAllClose(g(x), 0.)
g_primal, jvp = jax.jvp(g, (x,), (1.,))
self.assertAllClose(g_primal, 0.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(g)(x), 2.)

rewrites = {
lax.mul_p: lambda x, y: x + y
}
g = rewrite(f, rewrites)

self.assertAllClose(g(x), 4.)
g_primal, jvp = jax.jvp(g, (x,), (1.,))
self.assertAllClose(g_primal, 4.)
self.assertAllClose(jvp, 3.)
self.assertAllClose(grad(g)(x), 1.)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 252bd6c

Please sign in to comment.