revamp custom_jvp/vjp implementation to fix bugs
Co-authored-by: Dougal Maclaurin <[email protected]>
mattjj and dougalm committed Mar 30, 2020
1 parent 67283a0 commit 6193e5e
Showing 13 changed files with 373 additions and 286 deletions.
65 changes: 15 additions & 50 deletions design_notes/custom_derivatives.md
@@ -317,47 +317,25 @@ rely on the ability to round-trip to a jaxpr and back to a Python callable while
preserving semantics. That must mean preserving custom differentiation rule
semantics too.

The solution is for the partial evaluation rule for `custom_jvp_call` to stage
out an initial-style call-like primitive that can still be processed
correctly by `eval`, `jit`, `jvp` and/or `vmap` transformations. That means a
staged-out call-like primitive that carries with it enough information about `f`
and `f_jvp` to support all these transformations. We refer to this additional
primitive as `custom_jvp_call_jaxpr`. It is similar to `custom_jvp_call` except
it’s parameterized by a jaxpr for the primal function `f` rather than a Python
callable. The jaxpr for `f` is formed up-front before binding the primitive,
similar to other initial-style primitives.

(Three footnotes. First, we could refer to both the Python trace-time primitive
`custom_jvp_call`, which takes a wrapped Python callable as an argument, and the
jaxpr language primitive `custom_jvp_call_jaxpr`, which has a jaxpr as a
parameter, as simply "`custom_jvp_call`", analogously to how we refer to both
versions of `xla_call` as just "`xla_call`", but here we chose to use different
names to make the distinction more explicit. Second, for implementation
simplicity, both `custom_jvp_call` and `custom_jvp_call_jaxpr` have partial eval
rules that don’t do any nontrivial partial evaluation and instead stage
everything out. That doesn’t constrain automatic differentiation because
`custom_jvp_call_jaxpr`'s JVP rule doesn’t itself bind a call primitive but
instead just invokes the custom JVP rule callable. Third, we don’t form a jaxpr
for the JVP rule callable up-front, and instead keep it as a Python callable, to
avoid a recursion problem: in the common case that the JVP rule itself calls the
underlying custom-JVP function, we can’t trace the JVP rule up-front without
getting an infinite recursion. By not forming a jaxpr, we’re solving this in the
same way we always do: rules are Python callbacks invoked when a transformation
is applied, not part of the primitive, and though the rule here is associated
directly with the primitive, rather than being in a global dict, that’s just an
implementation detail.)
The solution is to use a bit of dynamic scoping: when we're staging out to a
jaxpr for an initial-style primitive, like those in lax_control_flow.py, we set
a bit on the global trace state. When that bit is set, instead of using the
final-style `custom_jvp_call` primitive, we use an initial-style
`custom_jvp_call_jaxpr` primitive, and trace the functions `f` and `f_jvp` to
jaxprs up-front to make initial-style processing easier. The
`custom_jvp_call_jaxpr` primitive is otherwise similar to the final-style
version.

(Footnote: while morally we form jaxprs for both `f` and `f_jvp` before binding
`custom_jvp_call_jaxpr`, we need to delay the formation of the jaxpr of `f_jvp`
because it may call the custom-JVP function and thus eager processing would lead
to an infinite recursion. We delay that jaxpr formation in a thunk.)
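
The delayed-formation idea in the footnote can be sketched in plain Python. This is a toy illustration, not JAX code: `ToyCustomJvp`, `bind`, and `memoized_thunk` are hypothetical names, and "tracing" the rule is simulated by simply calling it. The point is only that binding returns immediately, and the rule runs the first time the thunk is forced.

```python
import math


def memoized_thunk(build):
    """Return a zero-argument callable that runs `build` at most once."""
    cache = []

    def thunk():
        if not cache:
            cache.append(build())
        return cache[0]

    return thunk


class ToyCustomJvp:
    """Hypothetical stand-in for a custom-JVP function (not JAX's class)."""

    def __init__(self, f):
        self.f = f
        self.jvp_rule = None
        self.rule_traces = 0  # how many times the rule was "traced"

    def defjvp(self, rule):
        self.jvp_rule = rule
        return rule

    def bind(self, x):
        """Evaluate f now; defer "tracing" the rule behind a thunk."""
        def trace_rule():
            self.rule_traces += 1
            return self.jvp_rule(x, 1.0)

        return self.f(x), memoized_thunk(trace_rule)


exp = ToyCustomJvp(math.exp)


@exp.defjvp
def exp_jvp(x, t):
    # The rule calls the custom-JVP function itself. If bind() ran the rule
    # eagerly instead of wrapping it in a thunk, this would never terminate.
    y, _ = exp.bind(x)
    return y, y * t


out, jvp_thunk = exp.bind(0.0)
traces_before = exp.rule_traces   # the rule has not run yet
primal, tangent = jvp_thunk()     # the rule runs here, on demand
jvp_thunk()                       # cached: the rule does not run again
traces_after = exp.rule_traces
```

Memoizing the thunk is a design choice of this sketch so that forcing it repeatedly does the work once; whether the commit memoizes in the same way is not shown in this excerpt.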

If we gave up on [the Python flexibility
problem](#the-python-flexibility-problem), we could get away with only having
`custom_jvp_call_jaxpr` and not having the separate Python-level primitive
`custom_jvp_call`. One way to view the relationship between the two primitives
is in this schematic:
`custom_jvp_call`.

<div align="center">
<img
src="https://raw.githubusercontent.com/google/jax/master/images/custom_jvp_schematic.png"
alt="schematic"></img>
</div>

## API

@@ -456,17 +434,4 @@ There are some other bells and whistles to the API:
custom backward-pass function, and as a primitive it only has a transpose
rule.
* This mechanism is described more in [#636](https://github.com/google/jax/issues/636).
* Added a variant of `transformation_with_aux` called
`transformation_with_equal_aux` to allow repeated stores of equal values due
to running the same function multiple times.
* The custom rules functions, like `f_jvp` and `f_fwd`/`f_bwd` in the examples
above, are not “linear” in the sense of linear_util.py when used in
`custom_jvp_call_jaxpr` and `custom_vjp_call_jaxpr`, respectively. They may be
invoked multiple times as a jaxpr is processed in initial style. It’s
usually fine for rules to be invoked multiple times, but these rules must
plumb aux data out to the api.py-level caller, namely output pytree aux
data.
* (Recall from a footnote above that we can’t solve this by forming jaxprs for
the rules up-front because that can lead to infinite recursion.)


* To prevent
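
The notes in this hunk mention allowing "repeated stores of equal values" because a rule may run more than once. A minimal sketch of that idea follows, assuming a write-once store that accepts a repeated write only when the value is equal. `ToyEqualStore` and `run_rule` are hypothetical names for illustration and are not the `transformation_with_equal_aux` implementation.

```python
class ToyEqualStore:
    """Hypothetical write-once store that tolerates equal repeated writes."""

    _EMPTY = object()

    def __init__(self):
        self._val = self._EMPTY

    def store(self, val):
        # Reject a write only if a different value is already stored.
        if self._val is not self._EMPTY and self._val != val:
            raise ValueError(f"store holds {self._val!r}, got {val!r}")
        self._val = val

    @property
    def val(self):
        if self._val is self._EMPTY:
            raise ValueError("store is empty")
        return self._val


def run_rule(x, aux_store):
    """Stand-in for a rule that returns outputs and records aux data."""
    outputs = (x, 2 * x)
    # Record a description of the output structure as the aux data.
    aux_store.store(("tuple", len(outputs)))
    return outputs


aux = ToyEqualStore()
run_rule(1.0, aux)
run_rule(5.0, aux)            # second run stores an equal value: allowed
stored = aux.val

try:
    aux.store(("tuple", 3))   # a different value is rejected
    conflict_rejected = False
except ValueError:
    conflict_rejected = True
```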
Binary file removed images/custom_jvp_schematic.png
Binary file not shown.
2 changes: 1 addition & 1 deletion jax/api.py
@@ -1406,7 +1406,7 @@ def jaxpr_maker(*args, **kwargs):
    jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
    in_pvals = map(pv_like, jax_args)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
        jaxtree_fun, in_pvals, instantiate=True, stage_out_calls=True)
        jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
    out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
    in_avals = tuple(raise_to_shaped(in_aval) for in_aval, _ in in_pvals)
    typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
9 changes: 9 additions & 0 deletions jax/core.py
@@ -511,15 +511,18 @@ class Sublevel(int): pass
class TraceState(threading.local):
  trace_stack: TraceStack
  substack: List[Sublevel]
  initial_style: bool

  def __init__(self) -> None:
    self.trace_stack = TraceStack()
    self.substack = [Sublevel(0)]
    self.initial_style = False

  def copy(self):
    new = TraceState()
    new.trace_stack = self.trace_stack.copy()
    new.substack = self.substack[:]
    new.initial_style = self.initial_style
    return new
trace_state = TraceState()

@@ -574,6 +577,12 @@ def find_top_trace(xs):
  else:
    return type(top_trace)(top_trace.master, cur_sublevel())

@contextmanager
def initial_style_staging():
  prev, trace_state.initial_style = trace_state.initial_style, True
  yield
  trace_state.initial_style = prev
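
As written in this hunk, the line after `yield` is skipped if the body of the `with` block raises, so the flag would stay set after a tracing error. A possible exception-safe variant is sketched below with `try`/`finally`. This is a standalone illustration rather than the committed code: `ToyTraceState`, `toy_state`, and `toy_initial_style_staging` are hypothetical stand-ins so the snippet runs without JAX.

```python
import threading
from contextlib import contextmanager


class ToyTraceState(threading.local):
    """Hypothetical stand-in for the thread-local trace state."""

    def __init__(self):
        self.initial_style = False


toy_state = ToyTraceState()


@contextmanager
def toy_initial_style_staging():
    prev, toy_state.initial_style = toy_state.initial_style, True
    try:
        yield
    finally:
        # Runs on normal exit and also when the body raises.
        toy_state.initial_style = prev


observed = []
with toy_initial_style_staging():
    observed.append(toy_state.initial_style)      # set inside the block
    with toy_initial_style_staging():
        observed.append(toy_state.initial_style)  # still set when nested
    observed.append(toy_state.initial_style)      # inner exit restores outer value
observed.append(toy_state.initial_style)          # cleared again outside

try:
    with toy_initial_style_staging():
        raise ValueError("simulated tracing error")
except ValueError:
    pass
flag_after_error = toy_state.initial_style        # restored despite the error
```

Saving and restoring `prev`, rather than unconditionally resetting to `False`, is what keeps nested staging regions correct.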


# -------------------- abstract values --------------------
