Remove support for CUDA 11.
Pin the minimum required CUDA version to 12.1.

PiperOrigin-RevId: 618195554
jax authors committed Mar 22, 2024
1 parent c82deb2 commit bed4f65
Showing 14 changed files with 194 additions and 161 deletions.
24 changes: 0 additions & 24 deletions .bazelrc
@@ -228,30 +228,6 @@ build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1

build:rbe_linux_cuda11.8_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.8_nvcc_base --config=cuda_clang
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_NVCC_CLANG="1"
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.8_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.8_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.8"
build:rbe_linux_cuda11.8_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda11.8_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.8_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.8_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.8_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
build:rbe_linux_cuda11.8_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
build:rbe_linux_cuda11.8_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_platform//:platform"
build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_cuda"
build:rbe_linux_cuda11.8_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_nccl"
build:rbe_linux_cuda11.8_nvcc_py3.9 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.9"
build:rbe_linux_cuda11.8_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_cuda11.8_nvcc_py3.10 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.10"
build:rbe_linux_cuda11.8_nvcc_py3.10 --python_path="/usr/local/bin/python3.10"
build:rbe_linux_cuda11.8_nvcc_py3.11 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.11"
build:rbe_linux_cuda11.8_nvcc_py3.11 --python_path="/usr/local/bin/python3.11"
build:rbe_linux_cuda11.8_nvcc_py3.12 --config=rbe_linux_cuda11.8_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda11.8-cudnn8.6-tensorrt8.4_config_python3.12"
build:rbe_linux_cuda11.8_nvcc_py3.12 --python_path="/usr/local/bin/python3.12"

build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang
build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1"
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,10 @@ Remember to align the itemized text with the first line of an item within a list

## jaxlib 0.4.26

* Changes
* JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been
dropped.

## jax 0.4.25 (Feb 26, 2024)

* New Features
29 changes: 11 additions & 18 deletions docs/installation.md
@@ -61,7 +61,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.

You must first install the NVIDIA driver. We
recommend installing the newest driver available from NVIDIA, but the driver
must be version >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
must be version >= 525.60.13 for CUDA 12 on Linux.
If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
@@ -82,10 +82,6 @@ pip install --upgrade pip
# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If JAX detects the wrong version of the CUDA libraries, there are several things
@@ -113,14 +109,19 @@ able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.

JAX currently ships two CUDA wheel variants:
* CUDA 12.3, cuDNN 8.9, NCCL 2.16
* CUDA 11.8, cuDNN 8.6, NCCL 2.16
JAX currently ships one CUDA wheel variant:

| Built with | Compatible with |
|------------|-----------------|
| CUDA 12.3 | CUDA 12.1+ |
| cuDNN 8.9 | cuDNN 8.9+ |
| NCCL 2.19 | NCCL 2.18+ |

You may use a JAX wheel provided the major versions of your CUDA, cuDNN, and
NCCL installations match and the minor versions are the same or newer.
JAX checks the versions of these libraries and reports an error if they are not
sufficiently new.
Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable disables the
check, but using older versions of CUDA may lead to errors or incorrect
results.
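The matching rule described here can be sketched as a small helper (the function name and `(major, minor)` tuple encoding are illustrative, not part of JAX; patch versions are omitted for brevity):

```python
def wheel_compatible(installed, minimum):
    """True when an installed (major, minor) version satisfies the wheel's
    floor: the major versions must match and the minor must be equal or newer."""
    installed_major, installed_minor = installed
    minimum_major, minimum_minor = minimum
    return installed_major == minimum_major and installed_minor >= minimum_minor

# The CUDA 12.3 wheel is usable with any CUDA 12.1+ installation:
print(wheel_compatible((12, 2), (12, 1)))  # True
print(wheel_compatible((11, 8), (12, 1)))  # False: major version differs
```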

NCCL is an optional dependency, required only if you are performing multi-GPU
computations.
@@ -134,9 +135,6 @@ pip install --upgrade pip
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

**These `pip` installations do not work with Windows, and may fail silently; see
@@ -188,11 +186,6 @@ pip install -U --pre libtpu-nightly -f https://storage.googleapis.com/jax-releas
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```

* Jaxlib GPU (Cuda 11):
```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```

## Google TPU

### pip installation: Google Cloud TPU
32 changes: 11 additions & 21 deletions docs/tutorials/installation.md
@@ -72,7 +72,7 @@ NVIDIA has dropped support for Kepler GPUs in its software.

You must first install the NVIDIA driver. You're
recommended to install the newest driver available from NVIDIA, but the driver
version must be >= 525.60.13 for CUDA 12 and >= 450.80.02 for CUDA 11 on Linux.
version must be >= 525.60.13 for CUDA 12 on Linux.

If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
Expand All @@ -99,10 +99,6 @@ pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# NVIDIA CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things
@@ -131,15 +127,19 @@ able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.

JAX currently ships two NVIDIA CUDA wheel variants:
JAX currently ships one CUDA wheel variant:

- CUDA 12.2, cuDNN 8.9, NCCL 2.16
- CUDA 11.8, cuDNN 8.6, NCCL 2.16
| Built with | Compatible with |
|------------|-----------------|
| CUDA 12.3 | CUDA 12.1+ |
| cuDNN 8.9 | cuDNN 8.9+ |
| NCCL 2.19 | NCCL 2.18+ |

You may use a JAX wheel provided the major versions of your CUDA, cuDNN, and
NCCL installations match and the minor versions are the same or newer.
JAX checks the versions of these libraries and reports an error if they are not
sufficiently new.
Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable disables the
check, but using older versions of CUDA may lead to errors or incorrect
results.

NCCL is an optional dependency, required only if you are performing multi-GPU
computations.
@@ -152,10 +152,6 @@ pip install --upgrade pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with NVIDIA CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

**These `pip` installations do not work with Windows, and may fail silently; refer to the table
@@ -212,12 +208,6 @@ pip install -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/lib
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
```

- `jaxlib` NVIDIA GPU (CUDA 11):

```bash
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html
```

(install-google-tpu)=
## Google Cloud TPU

@@ -318,4 +308,4 @@ pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_re
For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example
```bash
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
```
129 changes: 112 additions & 17 deletions jax/_src/xla_bridge.py
@@ -31,6 +31,7 @@
import os
import pkgutil
import platform as py_platform
import traceback
import sys
import threading
from typing import Any, Callable, Union
@@ -267,33 +268,101 @@ def _check_cuda_compute_capability(devices_to_check):
RuntimeWarning
)

def _check_cuda_versions():

def _check_cuda_versions(raise_on_first_error: bool = False,
debug: bool = False):
assert cuda_versions is not None
results: list[dict[str, Any]] = []

def _make_msg(name: str,
runtime_version: int,
build_version: int,
min_supported: int,
debug_msg: bool = False):
if debug_msg:
return (f"Package: {name}\n"
f"Version JAX was built against: {build_version}\n"
f"Minimum supported: {min_supported}\n"
f"Installed version: {runtime_version}")
if min_supported:
req_str = (f"The local installation version must be no lower than "
f"{min_supported}.")
else:
req_str = ("The local installation must be the same version as "
"the version against which JAX was built.")
msg = (f"Outdated {name} installation found.\n"
f"Version JAX was built against: {build_version}\n"
f"Minimum supported: {min_supported}\n"
f"Installed version: {runtime_version}\n"
f"{req_str}")
return msg

def _version_check(name: str,
get_version,
get_build_version,
scale_for_comparison: int = 1,
min_supported_version: int = 0):
"""Checks the runtime CUDA component version against the JAX one.

Args:
  name: Name of the CUDA component.
  get_version: A function to get the local runtime version of the component.
  get_build_version: A function to get the version the component was built
    against.
  scale_for_comparison: Divisor used to round a version down so that
    patch/minor differences are ignored.
  min_supported_version: An absolute minimum required version. Must be
    passed without rounding down.
Raises:
  RuntimeError: If the component is not found or its version is unsupported,
    and raising the error is not deferred until later.
"""

def _version_check(name, get_version, get_build_version,
scale_for_comparison=1):
build_version = get_build_version()
try:
version = get_version()
except Exception as e:
raise RuntimeError(f"Unable to load {name}. Is it installed?") from e
if build_version // scale_for_comparison > version // scale_for_comparison:
raise RuntimeError(
f"Found {name} version {version}, but JAX was built against version "
f"{build_version}, which is newer. The copy of {name} that is "
"installed must be at least as new as the version against which JAX "
"was built."
)
err_msg = f"Unable to load {name}. Is it installed?"
if raise_on_first_error:
raise RuntimeError(err_msg) from e
err_msg += f"\n{traceback.format_exc()}"
results.append({"name": name, "installed": False, "msg": err_msg})
return

if not min_supported_version:
min_supported_version = build_version // scale_for_comparison
passed = min_supported_version <= version

if not passed or debug:
msg = _make_msg(name=name,
runtime_version=version,
build_version=build_version,
min_supported=min_supported_version,
debug_msg=passed)
if not passed and raise_on_first_error:
raise RuntimeError(msg)
else:
record = {"name": name,
"installed": True,
"msg": msg,
"passed": passed,
"build_version": build_version,
"version": version,
"minimum_supported": min_supported_version}
results.append(record)

_version_check("CUDA", cuda_versions.cuda_runtime_get_version,
cuda_versions.cuda_runtime_build_version)
cuda_versions.cuda_runtime_build_version,
scale_for_comparison=10,
min_supported_version=12010)
_version_check(
"cuDNN",
cuda_versions.cudnn_get_version,
cuda_versions.cudnn_build_version,
# NVIDIA promise both backwards and forwards compatibility for cuDNN patch
# versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
# versions:
# https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
scale_for_comparison=100,
min_supported_version=8900
)
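The integer constants passed as `min_supported_version` encode dotted versions. Judging only from the constants in this diff (CUDA 12.1 → 12010, cuDNN 8.9.0 → 8900 — each library uses its own scheme, so treat these encodings as assumptions), the role of `scale_for_comparison` can be sketched as:

```python
def encode_cuda_runtime(major, minor, patch=0):
    # CUDA-runtime-style encoding: 12.1.0 -> 12010 (matching the
    # min_supported_version=12010 used for the "CUDA" check).
    return 1000 * major + 10 * minor + patch

def encode_cudnn(major, minor, patch=0):
    # cuDNN-style encoding: 8.9.0 -> 8900 (matching the
    # min_supported_version=8900 used for the "cuDNN" check).
    return 1000 * major + 100 * minor + patch

# Floor-dividing by scale_for_comparison discards the trailing digits, so
# patch-level differences are ignored when versions are compared:
print(encode_cuda_runtime(12, 1, 5) // 10)  # 1201, same as for 12.1.0
print(encode_cudnn(8, 9, 4) // 100)         # 89, same as for 8.9.0
```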
_version_check("cuFFT", cuda_versions.cufft_get_version,
cuda_versions.cufft_build_version,
@@ -302,20 +371,42 @@ def _version_check(name, get_version, get_build_version,
_version_check("cuSOLVER", cuda_versions.cusolver_get_version,
cuda_versions.cusolver_build_version,
# Ignore patch versions.
scale_for_comparison=100)
scale_for_comparison=100,
min_supported_version=11400)
_version_check("cuPTI", cuda_versions.cupti_get_version,
cuda_versions.cupti_build_version)
cuda_versions.cupti_build_version,
min_supported_version=18)
# TODO(jakevdp) remove these checks when minimum jaxlib is v0.4.21
if hasattr(cuda_versions, "cublas_get_version"):
_version_check("cuBLAS", cuda_versions.cublas_get_version,
cuda_versions.cublas_build_version,
# Ignore patch versions.
scale_for_comparison=100)
scale_for_comparison=100,
min_supported_version=120100)
if hasattr(cuda_versions, "cusparse_get_version"):
_version_check("cuSPARSE", cuda_versions.cusparse_get_version,
cuda_versions.cusparse_build_version,
# Ignore patch versions.
scale_for_comparison=100)
scale_for_comparison=100,
min_supported_version=12100)

errors = []
debug_results = []
for result in results:
message: str = result['msg']
if not result['installed'] or not result['passed']:
errors.append(message)
else:
debug_results.append(message)

join_str = f'\n{"-" * 50}\n'
if debug_results:
print(f'CUDA components status (debug):\n'
f'{join_str.join(debug_results)}')
if errors:
raise RuntimeError(f'Unable to use CUDA because of the '
f'following issues with CUDA components:\n'
f'{join_str.join(errors)}')
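The collect-then-raise flow used here can be reduced to a minimal sketch (names are illustrative, not JAX's API): each check records its failure instead of raising immediately, and all failures are aggregated into a single error.

```python
def run_deferred_checks(checks):
    """Run every (name, fn) check; collect failures instead of failing fast,
    then raise once so the error lists all problematic components together."""
    errors = []
    for name, fn in checks:
        try:
            fn()
        except RuntimeError as e:
            errors.append(f"{name}: {e}")
    if errors:
        sep = "\n" + "-" * 50 + "\n"
        raise RuntimeError("Issues with CUDA components:" + sep + sep.join(errors))

def ok():
    pass

def too_old():
    raise RuntimeError("installed version is too old")

try:
    run_deferred_checks([("cuDNN", ok), ("cuFFT", too_old), ("cuPTI", too_old)])
except RuntimeError as e:
    print(e)  # one error naming both cuFFT and cuPTI
```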


def make_gpu_client(
Expand All @@ -335,6 +426,10 @@ def make_gpu_client(
if platform_name == "cuda":
if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"):
_check_cuda_versions()
else:
print('Skipped CUDA version constraint checks because the '
'JAX_SKIP_CUDA_CONSTRAINTS_CHECK env var is set.')

# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
if jaxlib.version.__version_info__ >= (0, 4, 26):
devices_to_check = (allowed_devices if allowed_devices else
2 changes: 1 addition & 1 deletion jax_plugins/cuda/__init__.py
@@ -25,7 +25,7 @@

# cuda_plugin_extension locates inside jaxlib. `jaxlib` is for testing without
# preinstalled jax cuda plugin packages.
for pkg_name in ['jax_cuda12_plugin', 'jax_cuda11_plugin', 'jaxlib']:
for pkg_name in ['jax_cuda12_plugin', 'jaxlib']:
try:
cuda_plugin_extension = importlib.import_module(
f'{pkg_name}.cuda_plugin_extension'
2 changes: 1 addition & 1 deletion jaxlib/gpu_linalg.py
@@ -24,7 +24,7 @@

from jaxlib import xla_client

for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_linalg = importlib.import_module(
f"{cuda_module_name}._linalg", package="jaxlib"
2 changes: 1 addition & 1 deletion jaxlib/gpu_prng.py
@@ -27,7 +27,7 @@
from .hlo_helpers import custom_call
from .gpu_common_utils import GpuLibNotLinkedError

for cuda_module_name in [".cuda", "jax_cuda12_plugin", "jax_cuda11_plugin"]:
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cuda_prng = importlib.import_module(
f"{cuda_module_name}._prng", package="jaxlib"