mention cloud tpus in readme
mattjj committed Dec 14, 2019
1 parent 764f007 commit 5c80036
Showing 2 changed files with 11 additions and 5 deletions.
14 changes: 10 additions & 4 deletions README.md
@@ -29,15 +29,17 @@ executed. But JAX also lets you just-in-time compile your own Python functions
into XLA-optimized kernels using a one-function API,
[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
composed arbitrarily, so you can express sophisticated algorithms and get
maximal performance without leaving Python.
maximal performance without leaving Python. You can even program multiple GPUs
or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
differentiate through the whole thing.
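
To make that concrete, here is a minimal sketch of composing these
transformations (illustrative only: the `loss` function, shapes, and data
below are made up rather than taken from the repo):

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, pmap

def loss(w, x):
    # Toy quadratic loss, just to have something to differentiate.
    return jnp.sum((x @ w) ** 2)

fast_grad = jit(grad(loss))   # XLA-compile the gradient computation
w, x = jnp.ones(3), jnp.ones((4, 3))
print(fast_grad(w, x))

# pmap runs the same program on every available device (SPMD); the leading
# axis of its input must match the device count (1 on a CPU-only machine).
xs = jnp.stack([x] * jax.device_count())
print(pmap(lambda xi: grad(loss)(w, xi))(xs))
```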

Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Others are
[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming, with more to come.
parallel programming of multiple accelerators, with more to come.
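
As a hedged sketch of what such a composition looks like in practice (the
`loss` function and batch data here are invented for illustration, echoing the
`perex_grads` one-liner that appears further down in this README):

```python
from jax import grad, jit, vmap
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single (x, y) example.
    return (jnp.dot(w, x) - y) ** 2

w = jnp.ones(3)
xs = jnp.ones((8, 3))   # a batch of 8 inputs
ys = jnp.zeros(8)       # a batch of 8 targets

# Differentiate with respect to w, vectorize over the batch axes of xs and ys,
# then compile the whole pipeline.
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
print(per_example_grads(w, xs, ys).shape)   # (8, 3): one gradient per example
```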

This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
@@ -72,11 +72,15 @@ perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grad
* [Reference documentation](#reference-documentation)

## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)

And for a deeper dive into JAX:
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
Colabs](https://github.com/google/jax/tree/master/cloud_tpu_colabs).

For a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of
2 changes: 1 addition & 1 deletion cloud_tpu_colabs/README.md
@@ -23,7 +23,7 @@ Solve the wave equation with `pmap`, and make cool movies! The spatial domain is
![](https://raw.githubusercontent.com/google/jax/master/cloud_tpu_colabs/images/wave_movie.gif)

### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/NeurIPS_2019_JAX_demo.ipynb)
An overview of JAX presented at the Program Transformations for ML workshop at NeurIPS 2019. Covers basic numpy usage, grad, jit, vmap, and pmap.
An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`.

## Performance notes
