add custom_jvp / vjp, delete custom_transforms
mattjj committed Mar 22, 2020
1 parent 6876271 commit 7e480fa
Showing 29 changed files with 4,414 additions and 1,551 deletions.
6 changes: 2 additions & 4 deletions .travis.yml
@@ -44,10 +44,8 @@ script:
- if [ "$JAX_ONLY_DOCUMENTATION" = true ]; then
sphinx-build -b html -D nbsphinx_execute=always docs docs/build/html ;
elif [ "$JAX_ONLY_CHECK_TYPES" = true ]; then
echo "===== Checking with mypy ====" &&
time mypy --config-file=mypy.ini jax &&
echo "===== Checking with pytype ====" &&
time pytype jax ;
echo "===== Checking with mypy ===="
time mypy --config-file=mypy.ini jax ;
else
pytest -n 1 tests examples -W ignore ;
fi
2 changes: 1 addition & 1 deletion README.md
@@ -136,7 +136,7 @@ print(grad(grad(grad(tanh)))(1.0))
For more advanced autodiff, you can use
[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
reverse-mode vector-Jacobian products and
[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp) for
[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
forward-mode Jacobian-vector products. The two can be composed arbitrarily with
one another, and with other JAX transformations. Here's one way to compose those
to make a function that efficiently computes [full Hessian
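A minimal sketch of one such forward-over-reverse composition (using `jax.jit`, `jax.jacfwd`, and `jax.jacrev`, which are built on `jvp` and `vjp`):

```python
from jax import jit, jacfwd, jacrev

def hessian(f):
    # reverse-mode builds the gradient; forward-mode then differentiates it again
    return jit(jacfwd(jacrev(f)))
```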
472 changes: 472 additions & 0 deletions design_notes/custom_derivatives.md

Large diffs are not rendered by default.
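The design note's diff isn't rendered above; as a rough sketch, the new `jax.custom_jvp` API this commit adds lets you pair a function with a hand-written JVP rule (the companion `jax.custom_vjp` takes a forward/backward pair instead), along these lines:

```python
import jax
import jax.numpy as jnp

# log(1 + exp(x)) paired with a hand-written derivative rule
@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    ans_dot = (1. - 1. / (1. + jnp.exp(x))) * x_dot
    return ans, ans_dot

print(jax.grad(log1pexp)(3.))  # differentiates via the custom JVP rule
```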

2 changes: 1 addition & 1 deletion docs/CHANGELOG.rst
@@ -77,7 +77,7 @@ jax 0.1.59 (February 11, 2020)
* Simplified :py:class:`Jaxpr` by removing the ``Jaxpr.freevars`` and
``Jaxpr.bound_subjaxprs`` fields. The call primitives (``xla_call``, ``xla_pmap``,
``sharded_call``, and ``remat_call``) get a new parameter ``call_jaxpr`` with a
fully-closed (no ``constvars``) JAXPR. Also, added a new field ``call_primitive``
fully-closed (no ``constvars``) jaxpr. Also, added a new field ``call_primitive``
to primitives.
* New features:

10 changes: 4 additions & 6 deletions docs/index.rst
@@ -13,22 +13,20 @@ For an introduction to JAX, start at the

notebooks/quickstart
notebooks/autodiff_cookbook
Training a Simple Neural Network, with PyTorch Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb>
notebooks/vmapped_log_probs
Training a Simple Neural Network, with Tensorflow Datasets Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb>


.. toctree::
:maxdepth: 1
:caption: Advanced JAX Tutorials

notebooks/Common_Gotchas_in_JAX
notebooks/XLA_in_Python
notebooks/Custom_derivative_rules_for_Python_code
notebooks/JAX_pytrees
notebooks/XLA_in_Python
notebooks/How_JAX_primitives_work
notebooks/Writing_custom_interpreters_in_Jax.ipynb
Training a Simple Neural Network, with Tensorflow Datasets Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb>
notebooks/maml
notebooks/score_matching
notebooks/vmapped_log_probs

.. toctree::
:maxdepth: 1
5 changes: 0 additions & 5 deletions docs/jax.rst
@@ -41,11 +41,6 @@ Automatic differentiation
.. autofunction:: jvp
.. autofunction:: linearize
.. autofunction:: vjp
.. autofunction:: custom_transforms
.. autofunction:: defjvp
.. autofunction:: defjvp_all
.. autofunction:: defvjp
.. autofunction:: defvjp_all
.. autofunction:: custom_gradient
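
These directives are removed along with the old ``custom_transforms``/``defvjp`` API; its
replacement is the new ``jax.custom_vjp`` decorator added by this commit. A rough,
illustrative sketch of the replacement usage (hypothetical example, not from the docs)::

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def clip_gradient(x):
        return x                          # identity on the forward pass

    def clip_gradient_fwd(x):
        return clip_gradient(x), None     # nothing to save for the backward pass

    def clip_gradient_bwd(_, g):
        return (jnp.clip(g, -1., 1.),)    # clip the incoming cotangent

    clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)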


66 changes: 33 additions & 33 deletions docs/jaxpr.rst
@@ -1,4 +1,4 @@
Understanding JAXPR
Understanding jaxprs
====================

Updated: February 14, 2020 (for commit 9e6fe64).
@@ -8,29 +8,29 @@ Updated: February 14, 2020 (for commit 9e6fe64).

Conceptually, one can think of JAX transformations as first tracing the Python
function to be transformed into a small and well-behaved intermediate form,
the JAXPR, that is then transformed accordingly, and ultimately compiled and executed.
the jaxpr, that is then transformed accordingly, and ultimately compiled and executed.
One of the reasons JAX can pack so much power into such a small software package
is that it starts with a familiar and flexible programming interface (Python with NumPy)
and it uses the actual Python interpreter to do most of the heavy lifting to distill the
essence of the computation into a simple statically-typed expression language
with limited higher-order features: the JAXPR language.
with limited higher-order features: the jaxpr language.

Not all Python programs can be processed this way, but it turns out that many
scientific computing and machine learning programs do have this property.

Before we proceed, it is important to point out that not all JAX transformations
materialize a JAXPR as described above; some, e.g., differentiation,
materialize a jaxpr as described above; some, e.g., differentiation,
will apply transformations incrementally during tracing.
Nevertheless, if one wants to understand how JAX works internally, or to
make use of the result of JAX tracing, it is useful to understand JAXPR.
make use of the result of JAX tracing, it is useful to understand jaxprs.
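
As a quick, hypothetical sketch of producing and printing a jaxpr with :py:func:`jax.make_jaxpr`::

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) + 2.

    print(jax.make_jaxpr(f)(3.0))
    # prints something like:
    # { lambda  ; a.
    #   let b = sin a
    #       c = add b 2.0
    #   in c }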

A JAXPR instance represents a function with one of more typed parameters (input variables)
A jaxpr instance represents a function with one or more typed parameters (input variables)
and one or more typed results. The results depend only on the input
variables; there are no free variables captured from enclosing scopes.
The inputs and outputs have types, which in JAX are represented as abstract
values. There are two related representations in the code for JAXPRs. The main
values. There are two related representations in the code for jaxprs. The main
one is :py:class:`jax.core.TypedJaxpr` and is what you obtain when you
use :py:func:`jax.make_jaxpr` to inspect JAXPRs. It has the following
use :py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following
fields:

* ``jaxpr``: is the actual computation content of the function (described below).
@@ -49,20 +49,20 @@ The most interesting part of the TypedJaxpr is the actual execution content,
represented as a :py:class:`jax.core.Jaxpr` as printed using the following
grammar::

JAXPR ::= { lambda Var* ; Var+.
jaxpr ::= { lambda Var* ; Var+.
let Eqn*
in [Expr+] }

where:
* The parameter of the JAXPR are shown as two lists of variables separated by
* The parameters of the jaxpr are shown as two lists of variables separated by
``;``. The first set of variables are the ones that have been introduced
to stand for constants that have been hoisted out. These are called the
`constvars`. The second list of variables are the real input variables.
* ``Eqn*`` is a list of equations, defining intermediate variables referring to
intermediate expressions. Each equation defines one or more variables as the
result of applying a primitive on some atomic expressions. Each equation uses only
input variables and intermediate variables defined by previous equations.
* ``Expr+``: is a list of output atomic expressions for the JAXPR.
* ``Expr+``: is a list of output atomic expressions for the jaxpr.

Equations are printed as follows::

@@ -79,14 +79,14 @@ where:
square brackets. Each parameter is shown as ``Name = Value``.


Most JAXPR primitives are first-order (they take just one or more Expr as arguments)::
Most jaxpr primitives are first-order (they take just one or more Expr as arguments)::

Primitive := add | sub | sin | mul | ...


The JAXPR primitives are documented in the :py:mod:`jax.lax` module.
The jaxpr primitives are documented in the :py:mod:`jax.lax` module.

For example, here is the JAXPR produced for the function ``func1`` below::
For example, here is the jaxpr produced for the function ``func1`` below::

from jax import numpy as jnp
def func1(first, second):
@@ -110,12 +110,12 @@ The ``reduce_sum`` primitive has named parameters ``axes`` and ``input_shape``,
addition to the operand ``e``.

Note that JAX traces through Python-level control-flow and higher-order functions
when it extracts the JAXPR. This means that just because a Python program contains
functions and control-flow, the resulting JAXPR does not have
when it extracts the jaxpr. This means that just because a Python program contains
functions and control-flow, the resulting jaxpr does not have
to contain control-flow or higher-order features.
For example, when tracing the function ``func3`` JAX will inline the call to
``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
JAXPR as before::
jaxpr as before::

def func2(inner, first, second):
temp = first + inner(second) * 3.
Expand All @@ -142,13 +142,13 @@ JAXPR as before::
Handling PyTrees
----------------

In JAXPR there are no tuple types; instead primitives take multiple inputs
In jaxpr there are no tuple types; instead primitives take multiple inputs
and produce multiple outputs. When processing a function that has structured
inputs or outputs, JAX will flatten those and in JAXPR they will appear as lists
inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists
of inputs and outputs. For more details, please see the documentation for
PyTrees (:doc:`notebooks/JAX_pytrees`).
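
As a hypothetical aside (distinct from the docs' example below), the flattening itself can be
inspected directly with ``jax.tree_util``::

    from jax import tree_util

    params = {"w": 1.0, "b": (2.0, 3.0)}
    leaves, treedef = tree_util.tree_flatten(params)   # flat list of leaves + structure record
    restored = tree_util.tree_unflatten(treedef, leaves)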

For example, the following code produces an identical JAXPR to what we saw
For example, the following code produces an identical jaxpr to what we saw
before (with two input vars, one for each element of the input tuple)::


@@ -184,7 +184,7 @@ from the Python program, or from constant-folding. For example, the function
print(api.make_jaxpr(func6)(jnp.ones(8)))


JAX produces the following JAXPR::
JAX produces the following jaxpr::

{ lambda b d a.
let c = add a b
@@ -196,13 +196,13 @@ When tracing ``func6``, the function ``func5`` is invoked with a constant value
``jnp.sin(second) * 3.`` is constant-folded.
There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
JAXPR notation what constants the constant variables stand for.
jaxpr notation what constants the constant variables stand for.
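
As an illustrative sketch (hypothetical code, not from the docs), a closed-over array constant is
hoisted in the same way::

    from jax import make_jaxpr
    import jax.numpy as jnp

    const = jnp.ones(8) * 3.

    def g(x):
        return x + const   # `const` is not an argument; it gets hoisted out

    print(make_jaxpr(g)(jnp.ones(8)))
    # the hoisted array appears as a constvar in the printed jaxpr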

Higher-order primitives
-----------------------

JAXPR includes several higher-order primitives. They are more complicated because
they include sub-JAXPRs.
jaxpr includes several higher-order primitives. They are more complicated because
they include sub-jaxprs.

Cond
^^^^
@@ -238,7 +238,7 @@ For example::

The cond primitive has a number of parameters:

* `true_jaxpr` and `false_jaxpr` are JAXPRs that correspond to the true
* `true_jaxpr` and `false_jaxpr` are jaxprs that correspond to the true
and false branch functionals. In this example, those functionals each take
one input variable, corresponding to ``xtrue`` and ``xfalse`` respectively.
* `linear` is a tuple of booleans that is used internally by the auto-differentiation
@@ -273,7 +273,7 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`::
in a } ] d b c e b c
in f }

The top-level JAXPR has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
The top-level jaxpr has one `constvar` ``e`` (corresponding to ``jnp.ones(1)`` from the
body of the ``false_jaxpr``) and three input variables ``a b c`` (corresponding to ``arg1``
and the two elements of ``arg2``; note that ``arg2`` has been flattened).
The ``true_jaxpr`` has two input variables (corresponding to the two elements of ``arg2``
@@ -286,10 +286,10 @@ The actual operands to the cond primitive are: ``d b c e b c``, which correspond

* 1 operand for the predicate,
* 2 operands for ``true_jaxpr``, i.e., ``b`` and ``c``, which are input vars,
corresponding to ``arg2`` for the top-level JAXPR,
* 1 constant for ``false_jaxpr``, i.e., ``e``, which is a consvar for the top-level JAXPR,
corresponding to ``arg2`` for the top-level jaxpr,
* 1 constant for ``false_jaxpr``, i.e., ``e``, which is a constvar for the top-level jaxpr,
* 2 operands for ``false_jaxpr``, i.e., ``b`` and ``c``, which are the input vars
corresponding to ``arg2`` for the top-level JAXPR.
corresponding to ``arg2`` for the top-level jaxpr.

While
^^^^^
@@ -328,7 +328,7 @@ For example, here is an example fori loop::
cond_nconsts=0 ] c a 0 b e
in h }

The top-level JAXPR has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
There are also two input variables (``a`` corresponding to ``arg`` and ``b`` corresponding
to ``n``).
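
A self-contained, hypothetical sketch that produces a similar ``while`` jaxpr::

    from jax import lax, make_jaxpr

    def count_to(n):
        # add 1. to the carry n times; with a traced bound this lowers to the while primitive
        return lax.fori_loop(0, n, lambda i, c: c + 1., 0.)

    print(make_jaxpr(count_to)(5))
    # the printed jaxpr should contain a `while` primitive whose cond_jaxpr and body_jaxpr
    # parameters hold the termination test and the loop body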
@@ -386,7 +386,7 @@ For the example consider the function ``func11`` below::
num_consts=1 ] b 0.0 a * c
in (d, e) }

The top-level JAXPR has one constvar ``c`` corresponding to the ``ones`` constant,
The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
and two input variables corresponding to the arguments ``arr`` and ``extra``.
The body of the scan has 5 input variables, of which:

@@ -413,7 +413,7 @@ XLA_call
^^^^^^^^

The call primitive arises from JIT compilation, and it encapsulates
a sub-JAXPR along with parameters the specify the backend and the device the
a sub-jaxpr along with parameters that specify the backend and the device on which the
computation should run. For example::

def func12(arg):
@@ -438,7 +438,7 @@ computation should run. For example::
The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
the top-level input variable `a` refers to the ``arg`` parameter of ``func12``.
The ``xla_call`` primitive stands for a call to the jitted ``inner`` function.
The primitive has the function body in the ``call_jaxpr`` parameter, a JAXPR
The primitive has the function body in the ``call_jaxpr`` parameter, a jaxpr
with 3 input parameters:

* ``c`` is a constvar and stands for the ``ones`` constant,