Skip to content

Commit

Permalink
Merge branch 'main' into workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
yashk2810 authored Sep 1, 2021
2 parents 67bea57 + 64631c0 commit b317f6a
Show file tree
Hide file tree
Showing 88 changed files with 2,479 additions and 1,136 deletions.
107 changes: 62 additions & 45 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build --announce_rc
build --define open_source_build=true

build --spawn_strategy=standalone
build --strategy=Genrule=standalone

build --enable_platform_specific_config

# Disable enabled-by-default TensorFlow features that we don't care about.
Expand Down Expand Up @@ -46,7 +46,7 @@ build:native_arch_posix --host_copt=-march=native
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1

build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="3.5,5.2,6.0,6.1,7.0"
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="3.5,5.2,6.0,6.1,7.0"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda
build:cuda --define=xla_python_enable_gpu=true
Expand All @@ -55,7 +55,7 @@ build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --define=xla_python_enable_gpu=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env=TF_ROCM_AMDGPU_TARGETS="gfx803,gfx900,gfx906,gfx1010"
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908"

build:nonccl --define=no_nccl_support=true

Expand Down Expand Up @@ -133,52 +133,69 @@ build:rbe_linux --host_linkopt=-lm

# Use the GPU toolchain until the CPU one is ready.
# https://github.com/bazelbuild/bazel/issues/13623
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux_base --config=rbe_linux
build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_cpu_linux_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"

build:rbe_cpu_linux_py37 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7"
build:rbe_cpu_linux_py37 --python_path="/usr/local/bin/python3.7"
build:rbe_cpu_linux_py38 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8"
build:rbe_cpu_linux_py38 --python_path="/usr/local/bin/python3.8"
build:rbe_cpu_linux_py39 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"
build:rbe_cpu_linux_py39 --python_path="/usr/local/bin/python3.9"

build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2"
build:rbe_linux_cuda_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda_base --action_env=GCC_HOST_COMPILER_PATH="/dt7/usr/bin/gcc"

# TensorRT 7 for CUDA 11.1 is compatible with CUDA 11.2, but requires
# libnvrtc.so.11.1. See https://github.com/NVIDIA/TensorRT/issues/1064.
# TODO(b/187962120): Remove when upgrading to TensorRT 8.
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64"

build:rbe_linux_cuda11.2_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.2_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.2_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.2_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.2_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.2_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.2_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_tensorrt"
build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_nccl"
build:rbe_linux_cuda11.2_nvcc_py3.6 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.6"
build:rbe_linux_cuda11.2_nvcc_py3.7 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7"
build:rbe_linux_cuda11.2_nvcc_py3.8 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8"
build:rbe_linux_cuda11.2_nvcc_py3.9 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"

# Map default to CUDA 11.2.
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.2_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.2_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.2_nvcc_py3.8
build:rbe_linux_cuda_nvcc_py39 --config=rbe_linux_cuda11.2_nvcc_py3.9

build:rbe_linux_py3 --config=rbe_linux
build:rbe_linux_py3 --python_path="/usr/local/bin/python3.9"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9"

build:rbe_linux_cuda11.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDA_VERSION=11
build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDNN_VERSION=8
build:rbe_linux_cuda11.1_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.1"
build:rbe_linux_cuda11.1_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda11.1_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt7/usr/bin/gcc"
test:rbe_linux_cuda11.1_nvcc_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64"
build:rbe_linux_cuda11.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_cuda"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_tensorrt"
build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_nccl"
build:rbe_linux_cuda11.1_nvcc_py3.7 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_python3.7"
build:rbe_linux_cuda11.1_nvcc_py3.7 --python_path="/usr/local/bin/python3.7"
build:rbe_linux_cuda11.1_nvcc_py3.8 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_python3.8"
build:rbe_linux_cuda11.1_nvcc_py3.8 --python_path="/usr/local/bin/python3.8"
build:rbe_linux_cuda11.1_nvcc_py3.9 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.1-cudnn8-tensorrt7.2_config_python3.9"
build:rbe_linux_cuda11.1_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"

build:rbe_linux_cuda10.2_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.2_nvcc_base --action_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda10.2_nvcc_base --action_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda10.2_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-10.2"
build:rbe_linux_cuda10.2_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:rbe_linux_cuda10.2_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt7/usr/bin/gcc"
test:rbe_linux_cuda10.2_nvcc_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-10.2/lib64"
build:rbe_linux_cuda10.2_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.2_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.2_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.2_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.2_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.2_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.2_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda10.2_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda10.2_nvcc_py3.7 --config=rbe_linux_cuda10.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda10.2_nvcc_py3.7 --python_path="/usr/local/bin/python3.7"
build:rbe_linux_cuda10.2_nvcc_py3.8 --config=rbe_linux_cuda10.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda10.2_nvcc_py3.8 --python_path="/usr/local/bin/python3.8"
build:rbe_linux_cuda10.2_nvcc_py3.9 --config=rbe_linux_cuda10.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.2-cudnn7-tensorrt6.0_config_python3.9"
build:rbe_linux_cuda10.2_nvcc_py3.9 --python_path="/usr/local/bin/python3.9"

# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
Expand Down
15 changes: 15 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates

version: 2
updates:
- package-ecosystem: pip
directory: /
schedule:
interval: weekly
- package-ecosystem: github-actions
directory: /
schedule:
interval: weekly
12 changes: 6 additions & 6 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
timeout-minutes: 5
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.8.0
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/head/main'}}
Expand All @@ -25,7 +25,7 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.8
- uses: pre-commit/[email protected].0
- uses: pre-commit/[email protected].3

build:
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
Expand All @@ -34,12 +34,12 @@ jobs:
strategy:
matrix:
include:
- name-prefix: "with many tests"
- name-prefix: "with 3.8"
python-version: 3.8
os: ubuntu-latest
enable-x64: 0
package-overrides: "none"
num_generated_cases: 25
num_generated_cases: 10
use-latest-jaxlib: false
- name-prefix: "with numpy-dispatch"
python-version: 3.9
Expand Down Expand Up @@ -68,7 +68,7 @@ jobs:
use-latest-jaxlib: true
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/head/main'}}
Expand Down Expand Up @@ -118,7 +118,7 @@ jobs:
python-version: [3.7]
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.7.0
uses: styfle/cancel-workflow-action@0.9.1
with:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/head/main'}}
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@

repos:
- repo: https://gitlab.com/pycqa/flake8
rev: '3.8.4'
rev: '3.9.2'
hooks:
- id: flake8

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.902'
rev: 'v0.910'
hooks:
- id: mypy
files: jax/
additional_dependencies: [types-requests==0.1.11]
additional_dependencies: [types-requests==0.1.11, jaxlib==0.1.70]

- repo: https://github.com/mwouts/jupytext
rev: v1.10.0
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.20 (unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.19...main).
* Breaking Changes
* `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
* `jnp.unique` and other set-like operations now require array-like inputs
({jax-issue}`#7662`)

## jax 0.2.19 (Aug 12, 2021)
* [GitHub
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -420,15 +420,16 @@ install [CUDA](https://developer.nvidia.com/cuda-downloads) and
[CuDNN](https://developer.nvidia.com/CUDNN),
if they have not already been installed. Unlike some other popular deep
learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip`
package. The CUDA 10 JAX wheels require CuDNN 7, whereas the CUDA 11 wheels of
package. JAX provides pre-built CUDA-compatible wheels for **linux only**;
the CUDA 10 JAX wheels require CuDNN 7, whereas the CUDA 11 wheels of
JAX require CuDNN 8. Other combinations of CUDA and CuDNN are possible but
require building from source.

Next, run

```bash
pip install --upgrade pip
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux.
```

The jaxlib version must correspond to the version of the existing CUDA
Expand Down
27 changes: 27 additions & 0 deletions benchmarks/api_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,33 @@ def pmap_simple_8_devices(state):
d.block_until_ready()


@google_benchmark.register
@required_devices(8)
def pmap_simple_dispatch_8_devices_100_args(state):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
args.append(jnp.array(list(range(i, i+8))))

args = f(*args)

while state:
args = f(*args)


@google_benchmark.register
@required_devices(8)
def pmap_simple_8_devices_100_args(state):
f = jax.pmap(lambda *args: args[1:] + (args[0] + 1,))
args = []
for i in range(100):
args.append(jnp.array(list(range(i, i+8))))

while state:
out = f(*args)
jax.tree_map(lambda x: x.block_until_ready(), out)


def _run_sda_index_bench(state, num_devices):
x = jax.pmap(jnp.sin)(jnp.arange(num_devices))
jax.device_get(x)
Expand Down
11 changes: 1 addition & 10 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def write_bazelrc(python_bin_path=None, remote_build=None,
with open("../.jax_configure.bazelrc", "w") as f:
if not remote_build and python_bin_path:
f.write(textwrap.dedent("""\
build --strategy=Genrule=standalone
build --repo_env PYTHON_BIN_PATH="{python_bin_path}"
build --action_env=PYENV_ROOT
build --python_path="{python_bin_path}"
Expand Down Expand Up @@ -374,18 +375,10 @@ def main():
"--cudnn_version",
default=None,
help="CUDNN version, e.g., 8")
parser.add_argument(
"--cuda_compute_capabilities",
default="3.5,5.2,6.0,6.1,7.0",
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_path",
default=None,
help="Path to the ROCm toolkit.")
parser.add_argument(
"--rocm_amdgpu_targets",
default="gfx803,gfx900,gfx906,gfx1010",
help="A comma-separated list of ROCm amdgpu targets to support.")
parser.add_argument(
"--bazel_startup_options",
action="append", default=[],
Expand Down Expand Up @@ -457,7 +450,6 @@ def main():
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
if cudnn_install_path:
print("CUDNN library path: {}".format(cudnn_install_path))
print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
if args.cuda_version:
print("CUDA version: {}".format(args.cuda_version))
if args.cudnn_version:
Expand All @@ -470,7 +462,6 @@ def main():
if args.enable_rocm:
if rocm_toolkit_path:
print("ROCm toolkit path: {}".format(rocm_toolkit_path))
print("ROCm amdgpu targets: {}".format(args.rocm_amdgpu_targets))

write_bazelrc(
python_bin_path=python_bin_path,
Expand Down
Loading

0 comments on commit b317f6a

Please sign in to comment.