Skip to content

Commit

Permalink
Prepare for jax and jaxlib 0.4.0 release
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 493733609
  • Loading branch information
yashk2810 authored and jax authors committed Dec 8, 2022
1 parent dd64760 commit 0118f8d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,15 +427,15 @@ learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip`
package.

JAX provides pre-built CUDA-compatible wheels for **Linux only**,
with CUDA 11.1 or newer, and CuDNN 8.0.5 or newer. Note these existing wheels are currently for `x86_64` architectures only. Other combinations of
with CUDA 11.4 or newer, and CuDNN 8.2 or newer. Note these existing wheels are currently for `x86_64` architectures only. Other combinations of
operating system, CUDA, and CuDNN are possible, but require [building from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).

* CUDA 11.1 or newer is *required*.
* CUDA 11.4 or newer is *required*.
* The supported cuDNN versions for the prebuilt wheels are:
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
* cuDNN 8.6 or newer. We recommend using the cuDNN 8.6 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.0.5 or newer.
* cuDNN 8.2 or newer.
* You *must* use an NVidia driver version that is at least as new as your
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
For example, if you have CUDA 11.4 update 4 installed, you must use NVidia
Expand All @@ -453,7 +453,7 @@ Next, run

```bash
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Expand All @@ -468,11 +468,11 @@ version for jaxlib explicitly:
```bash
pip install --upgrade pip

# Installs the wheel compatible with Cuda >= 11.8 and cudnn >= 8.6
pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

You can find your CUDA version with the command:
Expand All @@ -483,7 +483,7 @@ nvcc --version

Some GPU functionality expects the CUDA installation to be at
`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
(e.g. `cuda-11.1`). If CUDA is installed elsewhere on your system, you can either
(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either
create a symlink:

```bash
Expand Down
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "bfd40279b247d2d0b0dc5c5a776b595c9d4979889dcf0529c85fe9f6ff7a5255",
strip_prefix = "tensorflow-c21f137bc42450f10f7d04f9d263852827afd079",
sha256 = "47edef97c9b23661fd63621d522454f30772ac70a1fb5ff82864e566ef86be78",
strip_prefix = "tensorflow-f3cc513887e06150b6f870c522220dabedc58920",
urls = [
"https://github.com/tensorflow/tensorflow/archive/c21f137bc42450f10f7d04f9d263852827afd079.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/f3cc513887e06150b6f870c522220dabedc58920.tar.gz",
],
)

Expand Down
2 changes: 1 addition & 1 deletion jax/tools/colab_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TPU_DRIVER_MODE = 0


def setup_tpu(tpu_driver_version='tpu_driver_20221109'):
def setup_tpu(tpu_driver_version='tpu_driver_20221207'):
"""Sets up Colab to run on TPU.
Note: make sure the Colab Runtime is set to Accelerator: TPU.
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

from setuptools import setup, find_packages

_current_jaxlib_version = '0.3.25'
_current_jaxlib_version = '0.4.0'
# The following should be updated with each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.3.25'
_available_cuda_versions = ['11']
_default_cuda_version = '11'
_available_cudnn_versions = ['82', '86']
_default_cudnn_version = '86'
_libtpu_version = '0.1.dev20221109'
_libtpu_version = '0.1.dev20221207'

_dct = {}
with open('jax/version.py') as f:
Expand Down Expand Up @@ -96,7 +96,7 @@ def generate_proto(source):

# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# $ pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# $ pip install jax[cuda11_cudnn86] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**{f'cuda{cuda_version}_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}"
for cuda_version in _available_cuda_versions for cudnn_version in _available_cudnn_versions}
},
Expand Down

0 comments on commit 0118f8d

Please sign in to comment.