Skip to content

Commit

Permalink
write in-process AOT walkthrough doc
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Sep 2, 2022
1 parent 43db064 commit bb68fbe
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 0 deletions.
236 changes: 236 additions & 0 deletions docs/aot.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# Ahead-of-time lowering and compilation

JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
a function that is compiled and runs on accelerators or the CPU. As the JIT
acronym indicates, all compilation happens _just-in-time_ for execution.

Some situations call for _ahead-of-time_ (AOT) compilation instead. When you
want to fully compile prior to execution time, or you want control over when
different parts of the compilation process take place, JAX has some options for
you.

First, let's review the stages of compilation. Suppose that `f` is a
function/callable output by {func}`jax.jit`, say `f = jax.jit(F)` for some input
callable `F`. When it is invoked with arguments, say `f(x, y)` where `x` and `y`
are arrays, JAX does the following in order:

1. **Stage out** a specialized version of the original Python callable `F` to an
internal representation. The specialization reflects a restriction of `F` to
input types inferred from properties of the arguments `x` and `y` (usually
their shape and element type).

2. **Lower** this specialized, staged-out computation to the XLA compiler's
input language, MHLO.

3. **Compile** the lowered HLO program to produce an optimized executable for
the target device (CPU, GPU, or TPU).

4. **Execute** the compiled executable with the arrays `x` and `y` as arguments.

JAX's AOT API gives you direct control over steps #2, #3, and #4 (but [not
#1](#inspecting-staged-out-computations)), plus some other features along the
way. An example:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f.0 {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = mhlo.constant dense<2> : tensor<i32>
%1 = mhlo.multiply %0, %arg0 : tensor<i32>
%2 = mhlo.add %1, %arg1 : tensor<i32>
return %2 : tensor<i32>
}
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
DeviceArray(10, dtype=int32)
```

See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.

In place of `jax.jit` above, you can also `lower(...)` the result of
{func}`jax.pmap`, as well as `pjit` and `xmap` (from
{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In
each case, you can `compile()` the result similarly.

All optional arguments to `jit`---such as `static_argnums`---are respected in
the corresponding lowering, compilation, and execution. Again the same goes for
`pmap`, `pjit`, and `xmap`.

In the example above, we can replace the arguments to `lower` with any objects
that have `shape` and `dtype` attributes:

```python
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
DeviceArray(10, dtype=int32)
```

More generally, `lower` only needs its arguments to structurally supply what JAX
must know for specialization and lowering. For typical array arguments like the
ones above, this means `shape` and `dtype` fields. For static arguments, by
contrast, JAX needs actual array values (more on this
[below](#lowering-with-static-arguments)).

Invoking an AOT-compiled function with arguments that are incompatible with its
lowering raises an error:

```python
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f)(i32_scalar, i32_scalar).compile(x_1d, y_1d)
...
TypeError: Computation compiled for input types:
ShapedArray(int32[]), ShapedArray(int32[])
called with:
ShapedArray(int32[3]), ShapedArray(int32[3])

>>> x_f = y_f = 72.0
>>> jax.jit(f)(i32_scalar, i32_scalar).compile(x_f, y_f)
...
TypeError: Computation compiled for input types:
ShapedArray(int32[]), ShapedArray(int32[])
called with:
ShapedArray(float32[]), ShapedArray(float32[])
```

Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time
transformations](#aot-compiled-functions-cannot-be-transformed) such as
`jax.jit`, {func}`jax.grad`, and {func}`jax.vmap`.


## Lowering with static arguments

Lowering with static arguments underscores the interaction between options
passed to `jax.jit`, the arguments passed to `lower`, and the arguments needed
to invoke the resulting compiled function. Continuing with our example above:

```python
>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f.1 {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
%0 = mhlo.constant dense<14> : tensor<i32>
%1 = mhlo.add %0, %arg0 : tensor<i32>
return %1 : tensor<i32>
}
}
>>> lowered_with_x.compile()(5)
DeviceArray(19, dtype=int32)
```

Note that `lower` here takes two arguments as usual, but the subsequent compiled
function accepts only the remaining non-static second argument. The static first
argument (value 7) is taken as a constant at lowering time and built into the
lowered computation, where it is possibly folded in with other constants. In
this case, its multiplication by 2 is simplified, resulting in the constant 14.

Although the second argument to `lower` above can be replaced by a hollow
shape/dtype structure, it is necessary that the static first argument be a
concrete value. Otherwise, lowering would err:

```python
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
DeviceArray(25, dtype=int32)
```

## AOT-compiled functions cannot be transformed

Compiled functions are specialized to a particular set of argument "types," such
as arrays with a specific shape and element type in our running example. From
JAX's internal point of view, transformations such as {func}`jax.vmap` alter the
type signature of functions in a way that invalidates the compiled-for type
signature. As a policy, JAX simply disallows compiled functions to be involved
in transformations. Example:

```python
>>> def g(x):
... assert x.shape == (3, 2)
... return x @ jnp.ones(2)

>>> def make_z(*shape):
... return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
DeviceArray([[ 1., 5., 9.],
[13., 17., 21.],
[25., 29., 33.],
[37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
```

A similar error is raised when `g_aot` is involved in autodiff
(e.g. {func}`jax.grad`). For consistency, transformation by `jax.jit` is
disallowed as well, even though `jit` does not meaningfully modify its
argument's type signature.


## Debug information and analyses, when available

In addition to the primary AOT functionality (separate and explicit lowering,
compilation, and execution), JAX's various AOT stages also offer some additional
features to help with debugging and gathering compiler feedback.

For instance, as the initial example above shows, lowered functions often offer
a text representation. Compiled functions do the same, and also offer cost and
memory analyses from the compiler. All of these are provided via methods on the
{class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
`lowered.as_text()` and `compiled.cost_analysis()` above).

These methods are meant as an aid for manual inspection and debugging, not as a
reliably programmable API. Their availability and output vary by compiler,
platform, and runtime. This makes for two important caveats:

1. If some functionality is unavailable on JAX's current backend, then the
method for it returns something trivial (and `False`-like). For example, if
the compiler underlying JAX does not provide a cost analysis, then
`compiled.cost_analysis()` will be `None`.

2. If some functionality is available, there are still very limited guarantees
on what the corresponding method provides. The return value is not required
to be consistent---in type, structure, or value---across JAX configurations,
backends/platforms, versions, or even invocations of the method. JAX cannot
guarantee that the output of `compiled.cost_analysis()` on one day will
remain the same on the following day.

When in doubt, see the package API documentation for {mod}`jax.stages`.


## Inspecting staged-out computations

Stage #1 in the list at the top of this note mentions specialization and
staging, prior to lowering. JAX's internal notion of a function specialized to
the types of its arguments is not always a reified data structure in memory. To
explicitly construct a view of JAX's specialization of a function in the
internal [Jaxpr intermediate
language](https://jax.readthedocs.io/en/latest/jaxpr.html), see
{func}`jax.make_jaxpr`.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.

faq
async_dispatch
aot
jaxpr
notebooks/convolutions
pytrees
Expand Down
2 changes: 2 additions & 0 deletions jax/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
``jax.jit`` and ``jax.pmap``, also support a common means of explicit
lowering and compilation *ahead of time*. This module defines types
that represent the stages of this process.
For more, see the `AOT walkthrough <https://jax.readthedocs.io/en/latest/aot.html>`_.
"""

from jax._src.stages import (
Expand Down

0 comments on commit bb68fbe

Please sign in to comment.