Commit

Test code in docs and api.py docstrings (jax-ml#2994)
Also remove jaxpr doc tests from api_test.py.
Jamie Townsend authored May 16, 2020
1 parent 510af1d commit 670fab5
Showing 7 changed files with 126 additions and 247 deletions.
4 changes: 3 additions & 1 deletion .travis.yml
@@ -12,7 +12,7 @@ jobs:
    - python: "3.6"
      env: JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=25
    - python: "3.7"
-     env: JAX_ENABLE_X64=1 JAX_ONLY_DOCUMENTATION=true
+     env: JAX_ENABLE_X64=0 JAX_ONLY_DOCUMENTATION=true

before_install:
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
@@ -44,6 +44,8 @@ install:
script:
  - if [ "$JAX_ONLY_DOCUMENTATION" = true ]; then
      sphinx-build -b html -D nbsphinx_execute=always docs docs/build/html ;
+     pytest docs ;
+     pytest --doctest-modules jax/api.py ;
elif [ "$JAX_ONLY_CHECK_TYPES" = true ]; then
echo "===== Checking with mypy ====" &&
time mypy --config-file=mypy.ini jax ;
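The new CI step runs pytest's doctest collection over `jax/api.py`, and the `# doctest: +SKIP` directives added in the docs below keep examples with device-dependent or nondeterministic output from being executed. A minimal sketch of how `+SKIP` behaves, using the stdlib `doctest` runner and a made-up `add` function:

```python
import doctest

def add(a, b):
    """Add two numbers.

    >>> add(2, 3)
    5
    >>> add(2, 3)  # doctest: +SKIP
    this expected output is wrong, but the example is never run

    """
    return a + b

# Build and run the doctest attached to `add` directly, so we can
# inspect the failure count.
parser = doctest.DocTestParser()
test = parser.get_doctest(add.__doc__, {"add": add}, "add", None, 0)
runner = doctest.DocTestRunner(verbose=False)
runner.run(test)

print(runner.failures)  # 0: the skipped example was never checked
```

The first example runs normally; the second would fail if executed, but `+SKIP` tells the runner not to run it at all, which is exactly what the docs need for output that varies by device.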
10 changes: 5 additions & 5 deletions docs/async_dispatch.rst
@@ -8,7 +8,7 @@ program:
>>> from jax import numpy as np
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
->>> np.dot(x, x) + 3.
+>>> np.dot(x, x) + 3.  # doctest: +SKIP
DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
@@ -44,7 +44,7 @@ arbitrary amounts of work and avoid having the accelerator wait.

Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.

->>> %time np.dot(x, x)
+>>> %time np.dot(x, x)  # doctest: +SKIP
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
@@ -70,7 +70,7 @@ use the :meth:`~jax.DeviceArray.block_until_ready` method on a
:class:`DeviceArray` value to wait for the computation that produced it to
complete.

->>> %time onp.asarray(np.dot(x, x))
+>>> %time onp.asarray(np.dot(x, x))  # doctest: +SKIP
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
Out[16]:
@@ -87,7 +87,7 @@ array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
->>> %time np.dot(x, x).block_until_ready()
+>>> %time np.dot(x, x).block_until_ready()  # doctest: +SKIP
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
@@ -105,4 +105,4 @@ DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
245.62597656, 240.22348022]], dtype=float32)

Blocking without transferring the result back to Python is usually faster, and
-is often the best choice when writing microbenchmarks of computation times.
\ No newline at end of file
+is often the best choice when writing microbenchmarks of computation times.
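The timings above reflect the dispatch/compute split the doc describes: the call returns as soon as the work is enqueued, and only blocking (or converting to a NumPy array) waits for the actual result. A rough stand-in for that pattern, using a thread pool in place of an accelerator queue (the names here are illustrative, not JAX APIs):

```python
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)

def slow_matmul_stub():
    # Stand-in for device work: sleep in place of a real matmul.
    time.sleep(0.05)
    return 42

# "Dispatch": submitting returns a handle immediately, the way
# np.dot(x, x) returns before the computation finishes.
start = time.perf_counter()
future = pool.submit(slow_matmul_stub)
dispatch_time = time.perf_counter() - start

# "Block until ready": waiting on the handle measures the real work,
# the way .block_until_ready() does.
result = future.result()
total_time = time.perf_counter() - start
pool.shutdown()

print(result)                      # 42
print(dispatch_time < total_time)  # True: dispatch is nearly free
```

Timing only the submit call gives a misleadingly small number, which is exactly the microbenchmark pitfall the section warns about.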
5 changes: 3 additions & 2 deletions docs/faq.rst
@@ -88,7 +88,7 @@ By default, JAX arrays are placed uncommitted on the default device
(``jax.devices()[0]``).

>>> from jax import numpy as jnp
->>> print(jnp.ones(3).device_buffer.device())
+>>> print(jnp.ones(3).device_buffer.device())  # doctest: +SKIP
gpu:0

Computations involving uncommitted data are performed on the default
@@ -97,8 +97,9 @@ device and the results are uncommitted on the default device.
Data can also be placed explicitly on a device using :func:`jax.device_put`
with a ``device`` parameter, in which case it becomes **committed** to the device:

+>>> import jax
>>> from jax import device_put
->>> print(device_put(1, jax.devices()[2]).device_buffer.device())
+>>> print(device_put(1, jax.devices()[2]).device_buffer.device())  # doctest: +SKIP
gpu:2
gpu:2

Computations involving some committed inputs will happen on the
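The FAQ's placement rule (committed inputs pin the computation to their device; with no committed inputs the default device is used) can be modeled in a few lines. This is an illustrative sketch of the rule as the FAQ states it, not JAX's real internals, and `result_device` is a made-up name:

```python
# Illustrative model of the FAQ's device-placement rule; result_device
# is a hypothetical helper, not a JAX API.
def result_device(inputs, default_device="gpu:0"):
    """Pick the device an operation's result lands on.

    `inputs` is a list of (device, committed) pairs. Committed inputs
    pin the computation to their device; with no committed inputs the
    default device is used and the result stays uncommitted.
    """
    committed = {dev for dev, is_committed in inputs if is_committed}
    if len(committed) > 1:
        raise ValueError("inputs are committed to different devices")
    return committed.pop() if committed else default_device

print(result_device([("gpu:0", False)]))                   # gpu:0
print(result_device([("gpu:2", True), ("gpu:0", False)]))  # gpu:2
```

Mixing inputs committed to different devices is treated as an error in this sketch, mirroring the idea that a committed input fixes where the computation must run.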