Skip to content

Commit

Permalink
Merge pull request jax-ml#16050 from patrick-kidger:linearize-aux
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 540520135
  • Loading branch information
jax authors committed Jun 15, 2023
2 parents 6d7da07 + f2d64f6 commit 94674b9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
53 changes: 41 additions & 12 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple,
shaped_abstractify, _ensure_str_tuple, apply_flat_fun_nokwargs,
check_callable, debug_info, result_paths, flat_out_axes, debug_info_final)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
Expand Down Expand Up @@ -1963,7 +1963,18 @@ def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
tree_unflatten(out_tree, out_tangents),
tree_unflatten(aux_tree, aux()))

def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
@overload
def linearize(fun: Callable, *primals, has_aux: Literal[False] = False
) -> Tuple[Any, Callable]:
...

@overload
def linearize(fun: Callable, *primals, has_aux: Literal[True]
) -> Tuple[Any, Callable, Any]:
...

def linearize(fun: Callable, *primals, has_aux: bool = False
) -> Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]:
"""Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.
Args:
Expand All @@ -1974,12 +1985,17 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
evaluated. Should be a tuple of arrays, scalar, or standard Python
container thereof. The length of the tuple is equal to the number of
positional parameters of ``fun``.
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first
element is considered the output of the mathematical function to be linearized,
and the second is auxiliary data. Default False.
Returns:
A pair where the first element is the value of ``f(*primals)`` and the
second element is a function that evaluates the (forward-mode)
Jacobian-vector product of ``fun`` evaluated at ``primals`` without re-doing
the linearization work.
If ``has_aux`` is ``False``, returns a pair where the first element is the value of
``f(*primals)`` and the second element is a function that evaluates the
(forward-mode) Jacobian-vector product of ``fun`` evaluated at ``primals`` without
re-doing the linearization work. If ``has_aux`` is ``True``, returns a
``(primals_out, lin_fn, aux)`` tuple where ``aux`` is the auxiliary data returned by
``fun``.
In terms of values computed, :py:func:`linearize` behaves much like a curried
:py:func:`jvp`, where these two code blocks compute the same values::
Expand Down Expand Up @@ -2026,16 +2042,29 @@ def linearize(fun: Callable, *primals) -> Tuple[Any, Callable]:
"""
check_callable(fun)
f = lu.wrap_init(fun)
primals_flat, in_tree = tree_flatten((primals, {}))
jaxtree_fun, out_tree = flatten_fun(f, in_tree)
out_primals, out_pvals, jaxpr, consts = ad.linearize(jaxtree_fun, *primals_flat)
out_tree = out_tree()
primals_flat, in_tree = tree_flatten(primals)
if has_aux:
jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
else:
jaxtree_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
out_primals, out_pvals, jaxpr, consts, *maybe_aux = ad.linearize(jaxtree_fun,
*primals_flat,
has_aux=has_aux)
if has_aux:
out_tree, aux_tree = out_tree()
else:
out_tree = out_tree()
out_primal_py = tree_unflatten(out_tree, out_primals)
primal_avals = list(map(core.get_aval, primals_flat))
# Ensure that lifted_jvp is a PyTree
lifted_jvp = Partial(partial(_lift_linearized, jaxpr, primal_avals,
(in_tree, out_tree), out_pvals), consts)
return out_primal_py, lifted_jvp
if has_aux:
[aux] = maybe_aux
return out_primal_py, lifted_jvp, tree_unflatten(aux_tree, aux)
else:
[] = maybe_aux
return out_primal_py, lifted_jvp

def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
def fun(*tangents):
Expand All @@ -2052,7 +2081,7 @@ def fun(*tangents):
assert next(tangents_out_, None) is None
return full_out

return apply_flat_fun(fun, io_tree, *py_args)
return apply_flat_fun_nokwargs(fun, io_tree, py_args)

def _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree,
fun, *py_args_):
Expand Down
11 changes: 11 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3535,6 +3535,17 @@ def f():

f() # doesn't crash

def test_linearize_aux(self):
def fn(x):
return x * 2 - 3, x > 0

f, lin_fn, aux = api.linearize(fn, 3.4, has_aux=True)
tang = lin_fn(5.)

self.assertAllClose(f, 3.8)
self.assertAllClose(tang, 10.)
self.assertEqual(aux, True)

def test_linearize_aval_error(self):
# https://github.com/google/jax/issues/4622
f = lambda x: x
Expand Down

0 comments on commit 94674b9

Please sign in to comment.