Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/google/jax
Browse files Browse the repository at this point in the history
  • Loading branch information
gmittal committed Nov 27, 2020
2 parents b2a7be9 + d328816 commit 0e80916
Show file tree
Hide file tree
Showing 93 changed files with 5,966 additions and 4,476 deletions.
1 change: 1 addition & 0 deletions .bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.1.0
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
*.pyc
*.so
*.egg-info
*.whl
build/bazel*
dist/
.ipynb_checkpoints
/bazel-*
.bazelrc
/tensorflow
.DS_Store
build/
dist/
.mypy_cache/
.pytype/
docs/build
docs/notebooks/.ipynb_checkpoints/
docs/_autosummary
.idea
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ To cite this repository:

```
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and Skye Wanderman-Milne},
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.2.5},
Expand Down
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ http_archive(
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "579a74ad171d8da7b7193ff863f28482c2e6050c4090650b001fb80bbc46bb0f",
strip_prefix = "tensorflow-04f25b55e27be95ec340f414c2a1cabe16be5c2a",
sha256 = "36114803ece36c4a29cf99ca48fe3314410dcf86be3113d5e2a4ccae39c67ebd",
strip_prefix = "tensorflow-dad4331852dc31b34d45533200178b09dcfacb7a",
urls = [
"https://github.com/tensorflow/tensorflow/archive/04f25b55e27be95ec340f414c2a1cabe16be5c2a.tar.gz",
"https://github.com/tensorflow/tensorflow/archive/dad4331852dc31b34d45533200178b09dcfacb7a.tar.gz",
],
)

Expand Down
14 changes: 8 additions & 6 deletions build/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,27 @@
# JAX is Autograd and XLA

load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows")

licenses(["notice"]) # Apache 2

package(default_visibility = ["//visibility:public"])

sh_binary(
name = "install_xla_in_source_tree",
srcs = ["install_xla_in_source_tree.sh"],
py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
data = [
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
"//jaxlib",
"//jaxlib:lapack.so",
"//jaxlib:_pocketfft.so",
"//jaxlib:pocketfft_flatbuffers_py",
] + if_cuda([
] + if_not_windows([
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
]) + if_cuda([
"//jaxlib:cublas_kernels",
"//jaxlib:cusolver_kernels",
"//jaxlib:cuda_prng_kernels",
]),
deps = ["@bazel_tools//tools/bash/runfiles"],
deps = ["@bazel_tools//tools/python/runfiles"],
)
8 changes: 4 additions & 4 deletions build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ RUN /pyenv/bin/pyenv install 3.8.0
RUN /pyenv/bin/pyenv install 3.9.0

# We pin numpy to a version < 1.16 to avoid version compatibility issues.
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 setuptools wheel six auditwheel

# Change the CUDA version if it doesn't match the installed version.
ARG JAX_CUDA_VERSION=10.0
Expand Down
121 changes: 103 additions & 18 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
# pylint: enable=g-import-not-at-top


def is_windows():
return sys.platform.startswith("win32")


def shell(cmd):
output = subprocess.check_output(cmd)
return output.decode("UTF-8").strip()
Expand All @@ -52,7 +56,8 @@ def shell(cmd):

def get_python_bin_path(python_bin_path_flag):
"""Returns the path to the Python interpreter to use."""
return python_bin_path_flag or sys.executable
path = python_bin_path_flag or sys.executable
return path.replace(os.sep, "/")


def get_python_version(python_bin_path):
Expand All @@ -71,19 +76,24 @@ def check_python_version(python_version):

# Bazel

BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/2.0.0/"
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/3.1.0/"
BazelPackage = collections.namedtuple("BazelPackage", ["file", "sha256"])
bazel_packages = {
"Linux":
BazelPackage(
file="bazel-2.0.0-linux-x86_64",
file="bazel-3.1.0-linux-x86_64",
sha256=
"4df79462c6c3ecdeeee7af99fc269b52ab1aa4828ef3bc359c1837d3fafeeee7"),
"753434f4fa730266cf5ce21d1fdd425e1e167dd9347ad3e8adc19e8c0d54edca"),
"Darwin":
BazelPackage(
file="bazel-2.0.0-darwin-x86_64",
file="bazel-3.1.0-darwin-x86_64",
sha256=
"3eca4c96cfda97a9d5f8d3d0dec4155a5cc5ff339b10d3f35213c398bf13881e"),
"b7c5b07026eb653d431b7f15c569ecfc36a5f79427e66b5a55cab7ee885927ab"),
"Windows":
BazelPackage(
file="bazel-3.1.0-windows-x86_64.exe",
sha256=
"776db1f4986dacc3eda143932f00f7529f9ee65c7c1c004414c44aaa6419d0e9"),
}


Expand Down Expand Up @@ -131,7 +141,7 @@ def progress(block_count, block_size, total_size):
os.chmod(package.file,
st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

return "./" + package.file
return os.path.join(".", package.file)


def get_bazel_path(bazel_path_flag):
Expand Down Expand Up @@ -184,7 +194,8 @@ def check_bazel_version(bazel_path, min_version, max_version):
build --repo_env TF_NEED_CUDA="{tf_need_cuda}"
build --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"
build --distinct_host_configuration=false
build --copt=-Wno-sign-compare
build:linux --copt=-Wno-sign-compare
build:macos --copt=-Wno-sign-compare
build -c opt
build:opt --copt=-march=native
build:opt --host_copt=-march=native
Expand All @@ -200,9 +211,12 @@ def check_bazel_version(bazel_path, min_version, max_version):
build --define open_source_build=true
# Disable enabled-by-default TensorFlow features that we don't care about.
build --define=no_aws_support=true
build --define=no_gcp_support=true
build --define=no_hdfs_support=true
build:linux --define=no_aws_support=true
build:macos --define=no_aws_support=true
build:linux --define=no_gcp_support=true
build:macos --define=no_gcp_support=true
build:linux --define=no_hdfs_support=true
build:macos --define=no_hdfs_support=true
build --define=no_kafka_support=true
build --define=no_ignite_support=true
build --define=grpc_no_ares=true
Expand All @@ -213,16 +227,54 @@ def check_bazel_version(bazel_path, min_version, max_version):
build --spawn_strategy=standalone
build --strategy=Genrule=standalone
build --cxxopt=-std=c++14
build --host_cxxopt=-std=c++14
build --enable_platform_specific_config
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
build:windows --host_copt=/D_USE_MATH_DEFINES
# Make sure to include as little of windows.h as possible
build:windows --copt=-DWIN32_LEAN_AND_MEAN
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
build:windows --copt=-DNOGDI
build:windows --host_copt=-DNOGDI
# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/
# otherwise, there will be some compiling error due to preprocessing.
build:windows --copt=/Zc:preprocessor
build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14
build:macos --host_cxxopt=-std=c++14
build:windows --cxxopt=/std:c++14
build:windows --host_cxxopt=/std:c++14
# Generate PDB files, to generate useful PDBs, in opt compilation_mode
# --copt /Z7 is needed.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
build:windows --linkopt=/OPT:REF
build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
# Workaround for gcc 10+ warnings related to upb.
# See https://github.com/tensorflow/tensorflow/issues/39467
build:linux --copt=-Wno-stringop-truncation
"""



def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
cuda_version=None, cudnn_version=None, **kwargs):
with open("../.bazelrc", "w") as f:
f.write(BAZELRC_TEMPLATE.format(**kwargs))
if cuda_toolkit_path:
Expand All @@ -231,7 +283,12 @@ def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
if cudnn_install_path:
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
.format(cudnn_install_path=cudnn_install_path))

if cuda_version:
f.write("build --action_env TF_CUDA_VERSION=\"{cuda_version}\"\n"
.format(cuda_version=cuda_version))
if cudnn_version:
f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n"
.format(cudnn_version=cudnn_version))

BANNER = r"""
_ _ __ __
Expand Down Expand Up @@ -277,8 +334,9 @@ def add_boolean_argument(parser, name, default=False, help_str=None):


def main():
cwd = os.getcwd()
parser = argparse.ArgumentParser(
description="Builds libjax from source.", epilog=EPILOG)
description="Builds jaxlib from source.", epilog=EPILOG)
parser.add_argument(
"--bazel_path",
help="Path to the Bazel binary to use. The default is to find bazel via "
Expand Down Expand Up @@ -312,6 +370,14 @@ def main():
"--cudnn_path",
default=None,
help="Path to CUDNN libraries.")
parser.add_argument(
"--cuda_version",
default=None,
help="CUDA toolkit version, e.g., 11.1")
parser.add_argument(
"--cudnn_version",
default=None,
help="CUDNN version, e.g., 8")
parser.add_argument(
"--cuda_compute_capabilities",
default="3.5,5.2,6.0,6.1,7.0",
Expand All @@ -324,9 +390,21 @@ def main():
"--bazel_options",
action="append", default=[],
help="Additional options to pass to bazel.")
parser.add_argument(
"--output_path",
default=os.path.join(cwd, "dist"),
help="Directory to which the jaxlib wheel should be written")
args = parser.parse_args()

if is_windows() and args.enable_cuda:
if args.cuda_version is None:
parser.error("--cuda_version is needed for Windows CUDA build.")
if args.cudnn_version is None:
parser.error("--cudnn_version is needed for Windows CUDA build.")

print(BANNER)

output_path = os.path.abspath(args.output_path)
os.chdir(os.path.dirname(__file__ or args.prog) or '.')

# Find a working Bazel.
Expand All @@ -352,12 +430,18 @@ def main():
if cudnn_install_path:
print("CUDNN library path: {}".format(cudnn_install_path))
print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
if args.cuda_version:
print("CUDA version: {}".format(args.cuda_version))
if args.cudnn_version:
print("CUDNN version: {}".format(args.cudnn_version))
write_bazelrc(
python_bin_path=python_bin_path,
tf_need_cuda=1 if args.enable_cuda else 0,
cuda_toolkit_path=cuda_toolkit_path,
cudnn_install_path=cudnn_install_path,
cuda_compute_capabilities=args.cuda_compute_capabilities)
cuda_compute_capabilities=args.cuda_compute_capabilities,
cuda_version=args.cuda_version,
cudnn_version=args.cudnn_version)

print("\nBuilding XLA and installing it in the jaxlib source tree...")
config_args = args.bazel_options
Expand All @@ -371,7 +455,8 @@ def main():
config_args += ["--define=xla_python_enable_gpu=true"]
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] + config_args +
[":install_xla_in_source_tree", os.getcwd()])
[":build_wheel", "--",
f"--output_path={output_path}"])
print(" ".join(command))
shell(command)
shell([bazel_path, "shutdown"])
Expand Down
26 changes: 10 additions & 16 deletions build/build_jaxlib_wheels_macos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -e
# Builds wheels for multiple Python versions, using pyenv instead of Docker.
# Usage: run from root of JAX source tree as:
# build/build_jaxlib_wheels_macos.sh
# The wheels will end up in build/dist.
# The wheels will end up in dist/
#
# Requires pyenv, pyenv-virtualenv (e.g., from Homebrew). If you have Homebrew
# installed, you can install these with:
Expand All @@ -20,14 +20,11 @@ if ! pyenv --version 2>/dev/null ;then
fi
eval "$(pyenv init -)"

PLATFORM_TAG="macosx_10_9_x86_64"

build_jax () {
PY_VERSION="$1"
PY_TAG="$2"
NUMPY_VERSION="$3"
SCIPY_VERSION="$4"
echo -e "\nBuilding JAX for Python ${PY_VERSION}, tag ${PY_TAG}"
NUMPY_VERSION="$2"
SCIPY_VERSION="$3"
echo -e "\nBuilding JAX for Python ${PY_VERSION}"
echo "NumPy version ${NUMPY_VERSION}, SciPy version ${SCIPY_VERSION}"
pyenv install -s "${PY_VERSION}"
VENV="jax-build-${PY_VERSION}"
Expand All @@ -41,17 +38,14 @@ build_jax () {
# earlier Numpy versions.
pip install numpy==$NUMPY_VERSION scipy==$SCIPY_VERSION wheel future six
rm -fr build/build
python build/build.py
cd build
python setup.py bdist_wheel --python-tag "${PY_TAG}" --plat-name "${PLATFORM_TAG}"
cd ..
python build/build.py --output_path=dist/
pyenv deactivate
pyenv virtualenv-delete -f "${VENV}"
}


rm -fr build/dist
build_jax 3.6.8 cp36 1.15.4 1.2.0
build_jax 3.7.2 cp37 1.15.4 1.2.0
build_jax 3.8.0 cp38 1.17.3 1.3.2
build_jax 3.9.0 cp39 1.19.4 1.5.4
rm -fr dist
build_jax 3.6.8 1.15.4 1.2.0
build_jax 3.7.2 1.15.4 1.2.0
build_jax 3.8.0 1.17.3 1.3.2
build_jax 3.9.0 1.19.4 1.5.4
Loading

0 comments on commit 0e80916

Please sign in to comment.