DOC: many small fixes
elliotwaite committed Aug 4, 2021
1 parent df103f7 commit 7392a57
Showing 65 changed files with 199 additions and 199 deletions.
8 changes: 4 additions & 4 deletions CHANGELOG.md
@@ -50,7 +50,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
- not used with invalid `axis` value, or with an empty reduction dimension.
+ not used with an invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`)


@@ -333,7 +333,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Bug fixes:
* `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
complex inputs ({jax-issue}`#5156`)
- * `host_callback.id_tap` now works for `jax.pmap` also. There is a
+ * `host_callback.id_tap` now works for `jax.pmap` also. There is an
optional parameter for `id_tap` and `id_print` to request that the
device from which the value is tapped be passed as a keyword argument
to the tap function ({jax-issue}`#5182`).
@@ -359,7 +359,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
- * Add support for differentiating eigenvaleus computed by `jax.numpy.linalg.eig`
+ * Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
* Add support for building on Windows platforms
* Add support for general in_axes and out_axes in `jax.pmap`
* Add complex support for `jax.numpy.linalg.slogdet`
@@ -504,7 +504,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
* BFGS (#3101)
- * TPU suppot for half-precision arithmetic (#3878)
+ * TPU support for half-precision arithmetic (#3878)
* Bug Fixes:
* Prevent some accidental dtype warnings (#3874)
* Fix a multi-threading bug in custom derivatives (#3845, #3869)
3 changes: 2 additions & 1 deletion README.md
@@ -109,7 +109,8 @@ or the [examples](https://github.com/google/jax/tree/main/examples).
## Transformations

At its core, JAX is an extensible system for transforming numerical functions.
- Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`.
+ Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and
+ `pmap`.

### Automatic differentiation with `grad`

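As the README hunk above suggests, `grad`, `jit`, `vmap`, and `pmap` all wrap ordinary Python functions. A minimal sketch of the first three on a toy loss (the function and parameter names here are illustrative, not taken from the README):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.tanh(x @ w)

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

grad_loss = jax.grad(loss)                               # reverse-mode gradient w.r.t. w
fast_grad_loss = jax.jit(grad_loss)                      # compile end-to-end with XLA
batched_predict = jax.vmap(predict, in_axes=(None, 0))   # map over a batch of x
# jax.pmap has a vmap-like interface but runs replicas across multiple devices.
```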
2 changes: 1 addition & 1 deletion design_notes/omnistaging.md
@@ -99,7 +99,7 @@ The name "omnistaging" means staging out everything possible.

### Toy example

- iJAX transformations like `jit` and `pmap` stage out computations to XLA. That
+ JAX transformations like `jit` and `pmap` stage out computations to XLA. That
is, we apply them to functions comprising multiple primitive operations so that
rather being executed one at a time from Python the operations are all part of
one end-to-end optimized XLA computation.
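To make the "staging out" point concrete, here is a hedged toy example (not from the design note itself): under `jit`, the primitives below are traced once and compiled by XLA as a single end-to-end computation rather than dispatched one at a time from Python.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) * 2.0)   # three primitives: sin, mul, reduce_sum

f_jit = jax.jit(f)
x = jnp.arange(1000.0)
print(f_jit(x))   # first call traces and compiles; later calls reuse the executable
```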
17 changes: 9 additions & 8 deletions docs/autodidax.ipynb
@@ -666,7 +666,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice both `lift` and `sublift` package a value into a `JVPTracer` with the\n",
"Notice both `pure` and `lift` package a value into a `JVPTracer` with the\n",
"minimal amount of context, which is a zero tangent value.\n",
"\n",
"Let's add some JVP rules for primitives:"
@@ -1312,7 +1312,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Jaxpr data strutures\n",
"### Jaxpr data structures\n",
"\n",
"The jaxpr term syntax is roughly:\n",
"\n",
@@ -2720,7 +2720,7 @@
" g:float64[] = neg e\n",
" in ( g ) }\n",
"```\n",
"This second jaxpr is represents the linear computation that we want from\n",
"This second jaxpr represents the linear computation that we want from\n",
"`linearize`.\n",
"\n",
"However, unlike in this jaxpr example, we want the computation on known values\n",
@@ -2729,7 +2729,7 @@
"operations out of Python first before sorting out what can be evaluated now\n",
"and what must be delayed, we want only to form a jaxpr for those operations\n",
"that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
"of automatic differentiation, this is the feature ultimately enables us to\n",
"of automatic differentiation, this is the feature that ultimately enables us to\n",
"handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control\n",
"flow works because partial evaluation keeps the primal computation in Python.\n",
"As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out\n",
@@ -2874,9 +2874,10 @@
"(evaluating it in Python) and avoid forming tracers corresponding to the\n",
"output. If instead any input is unknown then we instead stage out into a\n",
"`JaxprEqnRecipe` representing the primitive application. To build the tracers\n",
"representing unknown outputs, we need avals, which get from the abstract eval\n",
"rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s\n",
"reference tracers; we avoid circular garbage by using weakrefs.)\n",
"representing unknown outputs, we need avals, which we get from the abstract\n",
"eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n",
"`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n",
"weakrefs.)\n",
"\n",
"That `process_primitive` logic applies to most primitives, but `xla_call_p`\n",
"requires recursive treatment. So we special-case its rule in a\n",
@@ -3312,7 +3313,7 @@
"metadata": {},
"source": [
"We use `UndefPrimal` instances to indicate which arguments with respect to\n",
"with we want to transpose. These arise because in general, being explicit\n",
"which we want to transpose. These arise because in general, being explicit\n",
"about closed-over values, we want to transpose functions of type\n",
"`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the\n",
"inputs with respect to which the function is linear could be scattered through\n",
17 changes: 9 additions & 8 deletions docs/autodidax.md
@@ -505,7 +505,7 @@ class JVPTrace(Trace):
jvp_rules = {}
```

- Notice both `lift` and `sublift` package a value into a `JVPTracer` with the
+ Notice both `pure` and `lift` package a value into a `JVPTracer` with the
minimal amount of context, which is a zero tangent value.

Let's add some JVP rules for primitives:
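For context, a JVP rule in autodidax's style looks roughly like the following sketch (`sin_p`, `sin`, and `cos` are assumed to be the primitive and helper functions defined earlier in the file; `jvp_rules` is the dict shown above):

```python
def sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Output primal and output tangent: d/dx sin(x) = cos(x)
    return [sin(x)], [cos(x) * x_dot]

jvp_rules[sin_p] = sin_jvp
```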
@@ -960,7 +960,7 @@ jaxpr and then interpreting the jaxpr.)

+++

- ### Jaxpr data strutures
+ ### Jaxpr data structures

The jaxpr term syntax is roughly:

@@ -2012,7 +2012,7 @@ and tangent jaxprs:
g:float64[] = neg e
in ( g ) }
```
- This second jaxpr is represents the linear computation that we want from
+ This second jaxpr represents the linear computation that we want from
`linearize`.

However, unlike in this jaxpr example, we want the computation on known values
@@ -2021,7 +2021,7 @@ forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all
operations out of Python first before sorting out what can be evaluated now
and what must be delayed, we want only to form a jaxpr for those operations
that _must_ be delayed due to a dependence on unknown inputs. In the context
- of automatic differentiation, this is the feature ultimately enables us to
+ of automatic differentiation, this is the feature that ultimately enables us to
handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
flow works because partial evaluation keeps the primal computation in Python.
As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
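The `grad` example mentioned above works in released JAX precisely because the primal branch is resolved in Python on the known value; a quick illustrative check:

```python
import jax

f = lambda x: x**2 if x > 0 else 0.
print(jax.grad(f)(3.0))    # 6.0  -- the `if` runs in Python on the known primal
print(jax.grad(f)(-1.0))   # 0.0
```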
@@ -2122,9 +2122,10 @@ inputs are known then we can bind the primitive on the known values
(evaluating it in Python) and avoid forming tracers corresponding to the
output. If instead any input is unknown then we instead stage out into a
`JaxprEqnRecipe` representing the primitive application. To build the tracers
- representing unknown outputs, we need avals, which get from the abstract eval
- rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s
- reference tracers; we avoid circular garbage by using weakrefs.)
+ representing unknown outputs, we need avals, which we get from the abstract
+ eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
+ `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
+ weakrefs.)

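A minimal illustration of the weakref pattern described in the parenthetical (the class and attribute names below are stand-ins, not autodidax's actual data layout):

```python
import weakref

class EqnRecipe:                      # stands in for JaxprEqnRecipe
    def __init__(self):
        self.out_tracer_refs = []

class EvalTracer:                     # stands in for a partial-eval tracer
    def __init__(self, recipe):
        self.recipe = recipe          # strong reference: tracer -> recipe

recipe = EqnRecipe()
tracer = EvalTracer(recipe)
recipe.out_tracer_refs.append(weakref.ref(tracer))   # weak back-reference: no cycle
```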
That `process_primitive` logic applies to most primitives, but `xla_call_p`
requires recursive treatment. So we special-case its rule in a
@@ -2468,7 +2469,7 @@ register_pytree_node(UndefPrimal,
```

We use `UndefPrimal` instances to indicate which arguments with respect to
- with we want to transpose. These arise because in general, being explicit
+ which we want to transpose. These arise because in general, being explicit
about closed-over values, we want to transpose functions of type
`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the
inputs with respect to which the function is linear could be scattered through
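In released JAX, the roughly corresponding user-facing tool is `jax.linear_transpose`; a hedged sketch of transposing a function that is linear in its input:

```python
import jax
import jax.numpy as jnp

def f(x):                  # linear in x
    return 3.0 * x

x_example = jnp.ones(3)
f_transpose = jax.linear_transpose(f, x_example)   # example input fixes shapes/dtypes
print(f_transpose(jnp.ones(3)))                    # (Array([3., 3., 3.], ...),)
```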
17 changes: 9 additions & 8 deletions docs/autodidax.py
@@ -486,7 +486,7 @@ def process_primitive(self, primitive, tracers, params):
jvp_rules = {}
# -

- # Notice both `lift` and `sublift` package a value into a `JVPTracer` with the
+ # Notice both `pure` and `lift` package a value into a `JVPTracer` with the
# minimal amount of context, which is a zero tangent value.
#
# Let's add some JVP rules for primitives:
@@ -919,7 +919,7 @@ def f(x):
# control flow, any transformation could be implemented by first tracing to a
# jaxpr and then interpreting the jaxpr.)

- # ### Jaxpr data strutures
+ # ### Jaxpr data structures
#
# The jaxpr term syntax is roughly:
#
@@ -1930,7 +1930,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# g:float64[] = neg e
# in ( g ) }
# ```
- # This second jaxpr is represents the linear computation that we want from
+ # This second jaxpr represents the linear computation that we want from
# `linearize`.
#
# However, unlike in this jaxpr example, we want the computation on known values
@@ -1939,7 +1939,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# operations out of Python first before sorting out what can be evaluated now
# and what must be delayed, we want only to form a jaxpr for those operations
# that _must_ be delayed due to a dependence on unknown inputs. In the context
- # of automatic differentiation, this is the feature ultimately enables us to
+ # of automatic differentiation, this is the feature that ultimately enables us to
# handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
# flow works because partial evaluation keeps the primal computation in Python.
# As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
@@ -2036,9 +2036,10 @@ def full_lower(self):
# (evaluating it in Python) and avoid forming tracers corresponding to the
# output. If instead any input is unknown then we instead stage out into a
# `JaxprEqnRecipe` representing the primitive application. To build the tracers
- # representing unknown outputs, we need avals, which get from the abstract eval
- # rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s
- # reference tracers; we avoid circular garbage by using weakrefs.)
+ # representing unknown outputs, we need avals, which we get from the abstract
+ # eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
+ # `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
+ # weakrefs.)
#
# That `process_primitive` logic applies to most primitives, but `xla_call_p`
# requires recursive treatment. So we special-case its rule in a
@@ -2376,7 +2377,7 @@ class UndefPrimal(NamedTuple):
# -

# We use `UndefPrimal` instances to indicate which arguments with respect to
- # with we want to transpose. These arise because in general, being explicit
+ # which we want to transpose. These arise because in general, being explicit
# about closed-over values, we want to transpose functions of type
# `a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the
# inputs with respect to which the function is linear could be scattered through
2 changes: 1 addition & 1 deletion docs/contributing.md
@@ -160,6 +160,6 @@ fix the issues you can push new commits to your branch.
Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This
will trigger a larger set of tests, including tests on GPU and TPU backends that are
not available via standard GitHub CI. Detailed results of these tests are not publicly
- viweable, but the JAX mantainer assigned to your PR will communicate with you regarding
+ viewable, but the JAX maintainer assigned to your PR will communicate with you regarding
any failures these might uncover; it's not uncommon, for example, that numerical tests
need different tolerances on TPU than on CPU.
2 changes: 1 addition & 1 deletion docs/custom_vjp_update.md
@@ -87,7 +87,7 @@ skip_app.defvjp(skip_app_fwd, skip_app_bwd)
## Explanation

Passing `Tracer`s into `nondiff_argnums` arguments was always buggy. While there
- were some cases which worked correctly, others would lead to complex and
+ were some cases that worked correctly, others would lead to complex and
confusing error messages.

The essence of the bug was that `nondiff_argnums` was implemented in a way that
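One pattern consistent with the explanation above is to keep traced, array-valued arguments as ordinary inputs and return a zero cotangent for them, rather than routing them through `nondiff_argnums`; a hedged sketch (the function names are illustrative, not the document's `skip_app` example):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def scale_by(x, s):          # s may be traced; we just don't differentiate through it
    return x * s

def scale_by_fwd(x, s):
    return x * s, s          # residuals saved for the backward pass

def scale_by_bwd(s, g):
    return (g * s, jnp.zeros_like(s))   # zero cotangent for s

scale_by.defvjp(scale_by_fwd, scale_by_bwd)
print(jax.grad(scale_by)(2.0, 3.0))     # 3.0
```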
6 changes: 3 additions & 3 deletions docs/developer.md
@@ -85,7 +85,7 @@ You can either install Python using its
[Windows installer](https://www.python.org/downloads/), or if you prefer, you
can use [Anaconda](https://docs.anaconda.com/anaconda/install/windows/)
or [Miniconda](https://docs.conda.io/en/latest/miniconda.html#windows-installers)
- to setup a Python environment.
+ to set up a Python environment.

Some targets of Bazel use bash utilities to do scripting, so [MSYS2](https://www.msys2.org)
is needed. See [Installing Bazel on Windows](https://docs.bazel.build/versions/master/install-windows.html#installing-compilers-and-language-runtimes)
@@ -174,7 +174,7 @@ python tests/lax_numpy_test.py --test_targets="testPad"

The Colab notebooks are tested for errors as part of the documentation build.

- Note that to run the full pmap tests on a (multi-core) CPU only machine, you
+ Note that to run the full pmap tests on a (multi-core) CPU-only machine, you
can run:

```
@@ -278,7 +278,7 @@ See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs

## Documentation building on readthedocs.io

- JAX's auto-generated documentations is at <https://jax.readthedocs.io/>.
+ JAX's auto-generated documentation is at <https://jax.readthedocs.io/>.

The documentation building is controlled for the entire project by the
[readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings
4 changes: 2 additions & 2 deletions docs/device_memory_profiling.md
@@ -80,7 +80,7 @@ For more information about how to interpret callgraph visualizations, see the

Functions compiled with {func}`jax.jit` are opaque to the device memory profiler.
That is, any memory allocated inside a `jit`-compiled function will be
- attributed to the function as whole.
+ attributed to the function as a whole.

In the example, the call to `block_until_ready()` is to ensure that `func2`
completes before the device memory profile is collected. See
@@ -90,7 +90,7 @@ completes before the device memory profile is collected. See

We can also use the JAX device memory profiler to track down memory leaks by using
`pprof` to visualize the change in memory usage between two device memory profiles
- taken at different times. For example consider the following program which
+ taken at different times. For example, consider the following program which
accumulates JAX arrays into a constantly-growing Python list.

```python
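The program referenced by this hunk is elided above; as a hedged, self-contained sketch of collecting a device memory profile with the `jax.profiler` API this document covers (array shapes and the output filename are illustrative):

```python
import jax.numpy as jnp
import jax.profiler

x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()   # ensure the allocation exists before profiling
jax.profiler.save_device_memory_profile("memory.prof")   # inspect with: pprof memory.prof
```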
10 changes: 5 additions & 5 deletions docs/faq.rst
@@ -60,10 +60,10 @@ If your ``jit`` decorated function takes tens of seconds (or more!) to run the
first time you call it, but executes quickly when called again, JAX is taking a
long time to trace or compile your code.

- This is usually a symptom of calling your function generating a large amount of
+ This is usually a sign that calling your function generates a large amount of
code in JAX's internal representation, typically because it makes heavy use of
- Python control flow such as ``for`` loop. For a handful of loop iterations
- Python is OK, but if you need _many_ loop iterations, you should rewrite your
+ Python control flow such as ``for`` loops. For a handful of loop iterations,
+ Python is OK, but if you need *many* loop iterations, you should rewrite your
code to make use of JAX's
`structured control flow primitives <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Structured-control-flow-primitives>`_
(such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can
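A hedged sketch of the ``lax.scan`` approach recommended above, which keeps the traced program small regardless of the number of iterations (illustrative, not the FAQ's own code):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def running_sum(xs):
    def body(carry, x):
        carry = carry + x
        return carry, carry          # (next carry, per-step output)
    total, partials = lax.scan(body, jnp.array(0.0), xs)
    return total, partials

print(running_sum(jnp.arange(5.0)))  # (10.0, [ 0.  1.  3.  6. 10.])
```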
@@ -206,7 +206,7 @@ running full applications, which inevitably include some amount of both data
transfer and compilation. Also, we were careful to pick large enough arrays
(1000x1000) and an intensive enough computation (the ``@`` operator is
performing matrix-matrix multiplication) to amortize the increased overhead of
- JAX/accelerators vs NumPy/CPU. For example, if switch this example to use
+ JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use
10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).

.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit
@@ -322,7 +322,7 @@ are not careful you may obtain a ``NaN`` for reverse differentiation::
jax.grad(my_log)(0.) ==> NaN

A short explanation is that during ``grad`` computation the adjoint corresponding
- to the undefined ``jnp.log(x)`` is a ``NaN`` and when it gets accumulated to the
+ to the undefined ``jnp.log(x)`` is a ``NaN`` and it gets accumulated to the
adjoint of the ``jnp.where``. The correct way to write such functions is to ensure
that there is a ``jnp.where`` *inside* the partially-defined function, to ensure
that the adjoint is always finite::
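The corrected function itself is elided by the hunk; here is a hedged sketch of the standard ``jnp.where``-inside pattern the FAQ describes (not necessarily its exact code):

```python
import jax
import jax.numpy as jnp

def safe_log(x):
    safe_x = jnp.where(x > 0., x, 1.)              # keep the undefined branch out of the adjoint
    return jnp.where(x > 0., jnp.log(safe_x), 0.)

print(jax.grad(safe_log)(0.))   # 0.0, not NaN
```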
2 changes: 1 addition & 1 deletion docs/jax-101/01-jax-basics.ipynb
@@ -219,7 +219,7 @@
"\n",
"(Like $\\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)\n",
"\n",
"This makes the JAX API quite different to other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n",
"This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n",
"\n",
"This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`."
]
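A hedged sketch of the pattern this cell describes, using the `sum_squared_error` name it mentions (the array values are illustrative):

```python
import jax
import jax.numpy as jnp

def sum_squared_error(x, y):
    return jnp.sum((x - y) ** 2)

# jax.grad differentiates with respect to the first argument by default.
sum_squared_error_dx = jax.grad(sum_squared_error)

x = jnp.arange(3.0)
y = jnp.ones(3)
print(sum_squared_error_dx(x, y))   # 2 * (x - y) = [-2.  0.  2.]
```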