This is a design document, explaining some of the thinking behind the design and implementation of `jax.custom_jvp` and `jax.custom_vjp`. For user-oriented documentation, see the tutorial notebook.
There are two ways to define differentiation rules in JAX:

1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and
2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This document is about #1 only.
We want users to customize the forward- and/or reverse-mode differentiation behavior of their code. This customization
- should have a clear and consistent semantics in how it works and how it composes with other JAX transformations; and
- should be flexible in supporting use cases and workflows like in Autograd and PyTorch, including cases involving differentiation of Python control flow and workflows for NaN debugging.
As JAX developers we want to write library functions, like `logit` and `expit`, that are defined in terms of other primitives, but for the purposes of differentiation have primitive-like behavior in the sense that we want to define custom differentiation rules for them, which may be more numerically stable or performant. In particular, we don't want to have to specify `vmap` or `jit` rules for functions like `logit` and `expit`.
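For instance, here is a minimal sketch of the kind of definition we have in mind, written with the `jax.custom_jvp` API proposed later in this document; the rule shown is illustrative, not the actual library implementation of `expit`:

```python
import jax
import jax.numpy as np  # the examples in this doc alias jax.numpy as np

# expit is defined in terms of existing primitives, so it inherits vmap and
# jit behavior for free; only its differentiation rule is customized.
@jax.custom_jvp
def expit(x):
  return 1. / (1. + np.exp(-x))

def expit_jvp(primals, tangents):
  x, = primals
  t, = tangents
  y = expit(x)
  return y, y * (1. - y) * t  # reuse the primal output for a stable tangent

expit.defjvp(expit_jvp)
```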
As a stretch goal, we’d like to make JAX a great environment for power users looking to add custom differentiation rules for higher-order functions like `fixed_point`, `odeint`, etc.; this design doc won’t solve that problem, but we want to be confident we’re not going to preclude good solutions to that problem.
That is, our primary goals are

1. solve the vmap-removes-custom-jvp semantics problem (#1249), and
2. allow Python in custom VJPs, e.g. to debug NaNs (#1275).

Secondary goals are

3. clean up and simplify user experience (symbolic zeros, kwargs, etc)
4. make progress towards a world where users can easily add `fixed_point`, `odeint`, `root`, etc.
Overall, we want to close #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938, and replace the custom_transforms machinery (from #636, #818, and others).
Here are objectives we're not aiming to achieve:
- The `custom_transforms` machinery aimed to provide a transformation-generic mechanism for customizing behavior, in principle (though never really used in practice) allowing users to customize rules for any transformation while somehow inheriting the “transparent” behavior for others. We are instead only going to solve the customization problem for differentiation (JVP and VJP, separately). Differentiation is the only case actually requested, and by specializing to differentiation we can reduce complexity and improve flexibility. To control all rules one can just write a primitive.
- We’re not going to prioritize mathematical aesthetics over flexibility and clarity on the user side, and simplicity on the implementation side. In particular, while the custom VJP signature `a -> (b, CT b --o CT a)` is mathematically pleasing, if it’s hard to implement in a Python mechanism because of the closure in the return type, we’re fine doing something that handles residuals more explicitly.
- Serialization support, of the form where the staged-out serialized program representation can be loaded and further JAX-transformed as opposed to just evaluated, is currently out of scope for these custom JVP/VJP transformation rules. Serialization may be useful not only for researchers who want to save some representation of their computation (and transform it after loading it), but also for future considerations like having jaxpr transformations implemented outside Python, or having jaxprs as an MLIR dialect. By defining this as a non-goal for the purpose of this design, we have fewer constraints on where we can stash Python callables.
The vmap-removes-custom-jvp semantics problem is that vmap does not compose properly with differentiation of functions with `custom_transforms` rules:
```python
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2
jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]
```
The last grad-of-vmap line has an unexpected result! In general, applying `vmap`, or really any non-differentiation transformation, has the effect of removing the custom differentiation rule. (Applying `jvp` causes a failure when a custom VJP rule is defined.)
The problem exists because transformations are like rewrites, and the `vmap` transformation effectively rewrites the function to no longer call the newly-introduced primitive for which there is a custom rule (and hence `grad` then doesn’t produce the custom rule’s result). In more detail, the `custom_transforms` machinery sets things up so that evaluating `f(x)` applies the function

```
{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }
```

where `f_primitive` is a new primitive (introduced for every `custom_transforms` function and in fact for every call of the function) to which the custom VJP rule is associated. When we evaluate `grad(f)(x)`, the differentiation machinery encounters `f_primitive` and processes it with the custom rule.
However, because `f_primitive` is transparent to `vmap`, in the sense that `vmap` operates on (effectively by inlining) the definition of `f_primitive`, the function `vmap(f)` is effectively

```
{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }
```

In words, `vmap` rewrites the function in terms of its underlying primitives and their transformation rules, removing `f_primitive` entirely.
More generally, because `vmap(f)` has semantics defined in terms of calls to `f`, it is semantically inconsistent to remove the custom derivative rule. That is, since we define

```python
vmap(f)(xs) == np.stack([f(x) for x in xs])
```

we must have

```python
jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))(xs)
```

yet this property is not observed when `f` has a custom derivative rule defined, as the custom derivative rule is used in the right-hand version but not the left-hand one.
This issue isn’t specific to `vmap`; it applies to all transformations for which the semantics of transforming a function `f` are defined in terms of calls to the function `f`, rather than rewriting it into another function. The `mask` transformation also falls into this class. Differentiation transforms and the hypothetical all-unary-functions-become-cosine transform are not in this class.

(The interaction between additional custom rules, like custom `vmap` rules, is likely to get even more complex, suggesting the problem framing of `custom_transforms` is too broad.)
In JAX, as in Autograd and PyTorch but not TF1, differentiation of a Python function is performed while the function is being executed and traced. This behavior delights users for a few reasons.
First and most importantly, it enables pdb-based workflows, e.g. for
inspecting numerics or catching NaNs. That is, users can employ the standard
Python debugger and other Python-native tools to debug their code, even being
able to inspect runtime values to understand numerical behavior on examples and
to catch fundamentally runtime errors like NaNs. In fact, just while working on
the PR corresponding to this design, especially on the `odeint` primitive, I
used runtime value inspection to debug issues many times, increasing my
confidence that this is a key user workflow in Python. One especially handy
trick, which I’ve used in both JAX and Autograd many times, is the ability to
insert a debugger breakpoint in a custom VJP rule to enter a debugger at a
specific point in the backward pass.
Second, it allows differentiation of Python native control flow. We’re not sure how often this is used in practice in finalized software artifacts, but when users first poke around JAX or Autograd they’re often impressed by this freedom. There’s a reason we include it at the top of our JAX and Autograd READMEs, slide decks, and demos. Ceding this capability would be a step backward from Autograd. We want JAX to have the best automatic differentiation.
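For example (a standard illustration, not from the original text), JAX can differentiate through Python branching, with the trace following whichever branch the concrete input selects:

```python
import jax

def relu_ish(x):
  if x > 0:            # ordinary Python control flow
    return 2. * x ** 3
  else:
    return 3. * x

print(jax.grad(relu_ish)(2.))   # 24.0, from the x > 0 branch
print(jax.grad(relu_ish)(-1.))  # 3.0, from the else branch
```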
However, the `custom_transforms` machinery does not provide this Python-support flexibility. That is, because it’s implemented in terms of up-front jaxpr formation from the Python code for both the user function and custom differentiation rules, code like this leads to an abstract value tracing error:
```python
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)
grad(f)(1.)  # Error!
```
The main idea is that dougalm@ already solved these problems with `core.call`. That is, we can frame the task of specifying a custom JVP rule for a user function in terms of a new Python-level call primitive (not to be added to the jaxpr language; see below). This new call primitive has a user Python function associated with it just like `core.call`, but additionally has a second Python callable representing the JVP rule. Let’s refer to this new call primitive as `custom_jvp_call`.
Transformations like `vmap` interact with `custom_jvp_call` as with `core.call`: they effectively pass right through it and are applied to the underlying Python callables. Schematically, writing in terms of curried versions of the primitives for convenience, analogously to how `vmap` interacts with `core.call` by applying to the function to be called:

```python
vmap(call(f)) == call(vmap(f))
```

for the new primitive `custom_jvp_call` we simply apply `vmap` to the two functions it entails:

```python
vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
```
This behavior means we’ve solved the vmap-removes-custom-jvp semantics problem.
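Concretely, revisiting the earlier example under the proposed `jax.custom_jvp` API (a sketch; the rule again deliberately returns 3 instead of 2 to make it visible), all three compositions now agree:

```python
import jax
import jax.numpy as np
from jax import grad, vmap

@jax.custom_jvp
def f(x):
  return 2. * x

def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), 3. * t  # 3 instead of 2, to make the custom rule visible

f.defjvp(f_jvp)

grad(f)(1.)                                   # 3.
vmap(grad(f))(np.ones(4))                     # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [3., 3., 3., 3.]
```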
The `jvp` transformation interacts as one might expect: it just calls `f_jvp`,

```python
jvp(call(f)) == call(jvp(f))
jvp(custom_jvp_call(f, f_jvp)) == f_jvp
```
Because `custom_jvp_call` acts like `core.call` (and not like `xla.xla_call`) in that it doesn’t raise the abstraction level of its inputs (because it’s not delaying anything or staging anything out), we’ve solved the Python flexibility problem: there are no constraints on the user Python function (beyond the usual functional programming constraints required by `jvp` or `vjp`).
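In particular, the control-flow example that errored under `custom_transforms` goes through under the new mechanism. Here is a sketch using the `jax.custom_vjp` API described below; the choice of residual and backward rule is illustrative:

```python
import jax
import jax.numpy as np

@jax.custom_vjp
def f(x):
  if x > 0:   # Python control flow in the primal function is fine
    return x
  else:
    return 0.

def f_fwd(x):
  return f(x), x  # save x as the residual

def f_bwd(x, g):
  return (np.where(x > 0, g, 0.),)

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(1.))  # 1.0, no tracing error
```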
What about evaluation and compilation? These are two ways to “exit” the JAX system, in the sense that no additional transformations can be applied after these steps. As a result, their rules are trivial:
```python
eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
```
In words, if a JVP rule hasn’t already rewritten `custom_jvp_call(f, f_jvp)` into `f_jvp`, when we get to the point of evaluation with `eval` or staging out to XLA with `jit`, differentiation is never going to be applied, so we just ignore `f_jvp` and behave just like `core.call`. However, due to the wrinkle discussed next, the partial eval rule for `custom_jvp_call` must be a bit more complex, since partial evaluation isn’t just used to stage out to XLA with `jit`.
The only remaining wrinkle has to do with “initial-style” jaxpr-forming primitives, like `lax.scan`, and their transformation rules. These represent a different kind of “staging out to a jaxpr” than that for compilation because we can perform additional transformations on the staged-out jaxpr. That is, when `lax.scan` forms a jaxpr, it does not exit the transformation system, since when we apply a `jvp` or `vmap` to a `lax.scan` we need to apply it to the function represented by the jaxpr.
Another way to state the wrinkle is that initial-style primitives like `lax.scan` rely on the ability to round-trip to a jaxpr and back to a Python callable while preserving semantics. That must mean preserving custom differentiation rule semantics too.
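As an illustration (a sketch, not from the original text), a function with a custom JVP keeps its rule even when called under `lax.scan`, which is exactly the round-trip property at stake; the rule again deliberately differs from the true derivative:

```python
import jax
import jax.numpy as np
from jax import lax

@jax.custom_jvp
def g(x):
  return 2. * x

def g_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return g(x), 3. * t  # deliberately not the true derivative

g.defjvp(g_jvp)

def sum_g(xs):
  # sum g(x) over xs using scan, an initial-style jaxpr-forming primitive
  total, _ = lax.scan(lambda c, x: (c + g(x), None), 0., xs)
  return total

print(jax.grad(sum_g)(np.ones(3)))  # [3. 3. 3.]: the custom rule survives scan
```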
The solution is to use a bit of dynamic scoping: when we’re staging out to a jaxpr for an initial-style primitive, like those in `lax_control_flow.py`, we set a bit on the global trace state. When that bit is set, instead of using the final-style `custom_jvp_call` primitive, we use an initial-style `custom_jvp_call_jaxpr` primitive, and trace the functions `f` and `f_jvp` to jaxprs up-front to make initial-style processing easier. The `custom_jvp_call_jaxpr` primitive is otherwise similar to the final-style version.
(Footnote: while morally we form jaxprs for both `f` and `f_jvp` before binding `custom_jvp_call_jaxpr`, we need to delay the formation of the jaxpr of `f_jvp` because it may call the custom-JVP function and thus eager processing would lead to an infinite recursion. We delay that jaxpr formation in a thunk.)
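To make the dynamic-scoping idea concrete, here is a hypothetical pure-Python sketch; the names (`TraceState`, `initial_style`, `initial_style_staging`) are illustrative stand-ins, not the actual `jax.core` internals:

```python
from contextlib import contextmanager

class TraceState:
  initial_style = False  # the "bit" on the global trace state

trace_state = TraceState()

@contextmanager
def initial_style_staging():
  # Initial-style primitives (like those in lax_control_flow.py) would wrap
  # their jaxpr-forming tracing in this context, so that custom-derivative
  # calls encountered inside bind custom_jvp_call_jaxpr rather than the
  # final-style custom_jvp_call.
  prev, trace_state.initial_style = trace_state.initial_style, True
  try:
    yield
  finally:
    trace_state.initial_style = prev
```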
If we gave up on the Python flexibility problem, we could get away with only having `custom_jvp_call_jaxpr` and not having the separate Python-level primitive `custom_jvp_call`.
The custom JVP for an `a -> b` function is specified with an `(a, T a) -> (b, T b)` function:
```python
# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)
```
(Interesting autodiff aside: for the rule to apply to higher-order differentiation, one must call `f` in the body of `f_jvp`; that precludes some kinds of work sharing between the internals of `f` and the tangent calculation.)
The custom VJP for an `a -> b` function is specified with an `a -> (b, c)` forward pass function paired with a `(c, CT b) -> CT a` backward pass function:
```python
# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)
```
The signature `a -> (b, CT b --o CT a)` is more aesthetically pleasing, but supporting it would make the implementation more complex and might require compromising expressibility desiderata. The basic reason is that Python callables are opaque (unless we trace them to a jaxpr eagerly, which places expressiveness constraints), and in this case we may be returning a callable with `vmap` tracers inside its closure that we need to know about during the forward pass.
We could add convenience wrappers, for example to define the JVP rule for a single argument at a time (like we do internally for primitives). But because this proposal is complicated enough as it is, I decided against convenience layers; let’s keep things minimal for now.
There are some other bells and whistles to the API:
- Input and output types `a`, `b`, and `c` can be arbitrary pytrees of jaxtypes.
- Passing arguments by name (keyword arguments) is supported when they can be resolved to positions using the `inspect` module. This is a bit of an experiment with Python 3’s improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. (See also #2069.)
- Arguments can be marked non-differentiable using `nondiff_argnums`, and as with `jit`’s `static_argnums` these arguments don’t have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal function with type signature `(d, a) -> b` where `d` represents the non-differentiable type, the JVP rule’s signature is `(a, T a, d) -> T b` and the VJP rule’s reverse component signature is `(d, c, CT b) -> CT a`. That is, the non-differentiable arguments are passed in order after `primals` and `tangents` for a custom JVP rule, and passed in order preceding the residuals in a custom VJP rule’s reverse function; see the sketch after this list.
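Here is a sketch of the VJP-side convention (the `scale` function and its rules are made up for this example), with the non-differentiable argument `n` passed to the backward rule ahead of the residuals:

```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def scale(n, x):           # primal :: (d, a) -> b
  return n * x

def scale_fwd(n, x):       # forward :: a -> (b, c), with d in scope
  return scale(n, x), x

def scale_bwd(n, x, g):    # backward :: (d, c, CT b) -> CT a
  return (n * g,)

scale.defvjp(scale_fwd, scale_bwd)

print(jax.grad(scale, argnums=1)(3, 2.0))  # 3.0
```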
- Updated `jax.experimental.odeint`
  - Since `odeint` is a pretty complex user of a custom VJP rule, in addition to just updating it to work at all, I wanted to revise it to be a canonical user of the new custom VJP API as a way to test that the API was a good one.
  - Along the way I made other improvements to the `odeint` implementation:
    - remove raveling/unraveling boilerplate
    - make use of `lax.scan` to remove the index-update logic
    - speed up by 20+% on the simple pendulum benchmark
- Added a custom bind method on each transform for the custom derivative call primitives, `custom_jvp_call` and `custom_vjp_call`. It’s like `core.call_bind`, except we don’t process env traces: those are just errors.
- Added `custom_lin` primitive, which gets staged out into linear jaxprs to be transposed when using a custom VJP rule.
  - Because our reverse-mode autodiff is decomposed into linearization, partial evaluation, and transposition, our custom VJP rules are processed in two separate steps: one during linearization and one during transposition.
  - The linearization step, i.e. the JVP rule for `custom_vjp_call`, applies `custom_lin` to the tangent values; `custom_lin` carries with it the user’s custom backward-pass function, and as a primitive it only has a transpose rule.
  - This mechanism is described more in #636.
- To prevent