Skip to content

Commit

Permalink
Merge pull request jax-ml#10668 from sharadmv:custom-interpreter-update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 448058058
  • Loading branch information
jax authors committed May 11, 2022
2 parents 43467f9 + aca9dc6 commit d092d63
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 28 deletions.
25 changes: 11 additions & 14 deletions docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@
"source": [
"To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a \"pretty-printing\" transformation:\n",
"it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.\n",
"Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.\n",
"Let's use it to look at how some example Jaxprs\n",
"are structured."
"`make_jaxpr` is useful for debugging and introspection.\n",
"Let's use it to look at how some example Jaxprs are structured."
]
},
{
Expand Down Expand Up @@ -201,7 +200,7 @@
"\n",
"### 1. Tracing a function\n",
"\n",
"We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`."
"Let's use `make_jaxpr` to trace a function into a Jaxpr."
]
},
{
Expand All @@ -227,8 +226,8 @@
"id": "CpTml2PTrzZ4"
},
"source": [
"This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr\n",
"from a list of partial value inputs."
"`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with\n",
"the constants (`literals`) from the trace."
]
},
{
Expand All @@ -243,7 +242,7 @@
" return jnp.exp(jnp.tanh(x))\n",
"\n",
"closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))\n",
"print(closed_jaxpr)\n",
"print(closed_jaxpr.jaxpr)\n",
"print(closed_jaxpr.literals)"
]
},
Expand Down Expand Up @@ -321,7 +320,7 @@
"source": [
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
"\n",
"Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
"Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
]
},
{
Expand Down Expand Up @@ -389,9 +388,8 @@
"def inverse(fun):\n",
" @wraps(fun)\n",
" def wrapped(*args, **kwargs):\n",
" # Since we assume unary functions, we won't\n",
" # worry about flattening and\n",
" # unflattening arguments\n",
" # Since we assume unary functions, we won't worry about flattening and\n",
" # unflattening arguments.\n",
" closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)\n",
" out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)\n",
" return out[0]\n",
Expand Down Expand Up @@ -434,9 +432,8 @@
" # outvars are now invars \n",
" invals = safe_map(read, eqn.outvars)\n",
" if eqn.primitive not in inverse_registry:\n",
" raise NotImplementedError(\"{} does not have registered inverse.\".format(\n",
" eqn.primitive\n",
" ))\n",
" raise NotImplementedError(\n",
" f\"{eqn.primitive} does not have registered inverse.\")\n",
" # Assuming a unary function \n",
" outval = inverse_registry[eqn.primitive](*invals)\n",
" safe_map(write, eqn.invars, [outval])\n",
Expand Down
25 changes: 11 additions & 14 deletions docs/notebooks/Writing_custom_interpreters_in_Jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ for function transformation.

To get a first look at Jaxprs, consider the `make_jaxpr` transformation. `make_jaxpr` is essentially a "pretty-printing" transformation:
it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation.
Although we can't generally use the Jaxprs that it returns, it is useful for debugging and introspection.
Let's use it to look at how some example Jaxprs
are structured.
`make_jaxpr` is useful for debugging and introspection.
Let's use it to look at how some example Jaxprs are structured.

```{code-cell} ipython3
:id: RSxEiWi-EeYW
Expand Down Expand Up @@ -139,7 +138,7 @@ The way we'll implement this is by (1) tracing `f` into a Jaxpr, then (2) interp

### 1. Tracing a function

We can't use `make_jaxpr` for this, because we need to pull out constants created during the trace to pass into the Jaxpr. However, we can write a function that does something very similar to `make_jaxpr`.
Let's use `make_jaxpr` to trace a function into a Jaxpr.

```{code-cell} ipython3
:id: BHkg_3P1pXJj
Expand All @@ -155,8 +154,8 @@ from jax._src.util import safe_map

+++ {"id": "CpTml2PTrzZ4"}

This function first flattens its arguments into a list, which are the abstracted and wrapped as partial values. The `jax.make_jaxpr` function is used to then trace a function into a Jaxpr
from a list of partial value inputs.
`jax.make_jaxpr` returns a *closed* Jaxpr, which is a Jaxpr that has been bundled with
the constants (`literals`) from the trace.

```{code-cell} ipython3
:id: Tc1REN5aq_fH
Expand All @@ -165,7 +164,7 @@ def f(x):
return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr)
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
```

Expand Down Expand Up @@ -224,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))

Notice that `eval_jaxpr` will always return a flat list even if the original function does not.

Furthermore, this interpreter does not handle `subjaxprs`, which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.

+++ {"id": "0vb2ZoGrCMM4"}

Expand Down Expand Up @@ -261,9 +260,8 @@ inverse_registry[lax.tanh_p] = jnp.arctanh
def inverse(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
# Since we assume unary functions, we won't
# worry about flattening and
# unflattening arguments
# Since we assume unary functions, we won't worry about flattening and
# unflattening arguments.
closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out[0]
Expand Down Expand Up @@ -296,9 +294,8 @@ def inverse_jaxpr(jaxpr, consts, *args):
# outvars are now invars
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError("{} does not have registered inverse.".format(
eqn.primitive
))
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
# Assuming a unary function
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])
Expand Down

0 comments on commit d092d63

Please sign in to comment.