Skip to content

Commit

Permalink
Change jaxlib build rules to build a wheel, rather than writing outpu…
Browse files Browse the repository at this point in the history
…t to the source directory.
  • Loading branch information
hawkinsp committed Nov 20, 2020
1 parent 5a41779 commit c06ead6
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 163 deletions.
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
*.pyc
*.so
*.egg-info
*.whl
build/bazel*
dist/
.ipynb_checkpoints
/bazel-*
.bazelrc
/tensorflow
.DS_Store
build/
dist/
.mypy_cache/
.pytype/
docs/build
docs/notebooks/.ipynb_checkpoints/
docs/_autosummary
.idea
Expand Down
4 changes: 2 additions & 2 deletions build/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ licenses(["notice"]) # Apache 2
package(default_visibility = ["//visibility:public"])

py_binary(
name = "install_xla_in_source_tree",
srcs = ["install_xla_in_source_tree.py"],
name = "build_wheel",
srcs = ["build_wheel.py"],
data = [
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
"//jaxlib",
Expand Down
8 changes: 4 additions & 4 deletions build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ RUN /pyenv/bin/pyenv install 3.8.0
RUN /pyenv/bin/pyenv install 3.9.0

# We pin numpy to a version < 1.16 to avoid version compatibility issues.
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 setuptools wheel six auditwheel

# Change the CUDA version if it doesn't match the installed version.
ARG JAX_CUDA_VERSION=10.0
Expand Down
12 changes: 10 additions & 2 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,9 @@ def add_boolean_argument(parser, name, default=False, help_str=None):


def main():
cwd = os.getcwd()
parser = argparse.ArgumentParser(
description="Builds libjax from source.", epilog=EPILOG)
description="Builds jaxlib from source.", epilog=EPILOG)
parser.add_argument(
"--bazel_path",
help="Path to the Bazel binary to use. The default is to find bazel via "
Expand Down Expand Up @@ -388,6 +389,10 @@ def main():
"--bazel_options",
action="append", default=[],
help="Additional options to pass to bazel.")
parser.add_argument(
"--output_path",
default=os.path.join(cwd, "dist"),
help="Directory to which the jaxlib wheel should be written")
args = parser.parse_args()

if is_windows() and args.enable_cuda:
Expand All @@ -397,6 +402,8 @@ def main():
parser.error("--cudnn_version is needed for Windows CUDA build.")

print(BANNER)

output_path = os.path.abspath(args.output_path)
os.chdir(os.path.dirname(__file__ or args.prog) or '.')

# Find a working Bazel.
Expand Down Expand Up @@ -447,7 +454,8 @@ def main():
config_args += ["--define=xla_python_enable_gpu=true"]
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] + config_args +
[":install_xla_in_source_tree", os.getcwd()])
[":build_wheel", "--",
f"--output_path={output_path}"])
print(" ".join(command))
shell(command)
shell([bazel_path, "shutdown"])
Expand Down
26 changes: 10 additions & 16 deletions build/build_jaxlib_wheels_macos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -e
# Builds wheels for multiple Python versions, using pyenv instead of Docker.
# Usage: run from root of JAX source tree as:
# build/build_jaxlib_wheels_macos.sh
# The wheels will end up in build/dist.
# The wheels will end up in dist/
#
# Requires pyenv, pyenv-virtualenv (e.g., from Homebrew). If you have Homebrew
# installed, you can install these with:
Expand All @@ -20,14 +20,11 @@ if ! pyenv --version 2>/dev/null ;then
fi
eval "$(pyenv init -)"

PLATFORM_TAG="macosx_10_9_x86_64"

build_jax () {
PY_VERSION="$1"
PY_TAG="$2"
NUMPY_VERSION="$3"
SCIPY_VERSION="$4"
echo -e "\nBuilding JAX for Python ${PY_VERSION}, tag ${PY_TAG}"
NUMPY_VERSION="$2"
SCIPY_VERSION="$3"
echo -e "\nBuilding JAX for Python ${PY_VERSION}"
echo "NumPy version ${NUMPY_VERSION}, SciPy version ${SCIPY_VERSION}"
pyenv install -s "${PY_VERSION}"
VENV="jax-build-${PY_VERSION}"
Expand All @@ -41,17 +38,14 @@ build_jax () {
# earlier Numpy versions.
pip install numpy==$NUMPY_VERSION scipy==$SCIPY_VERSION wheel future six
rm -fr build/build
python build/build.py
cd build
python setup.py bdist_wheel --python-tag "${PY_TAG}" --plat-name "${PLATFORM_TAG}"
cd ..
python build/build.py --output_path=dist/
pyenv deactivate
pyenv virtualenv-delete -f "${VENV}"
}


rm -fr build/dist
build_jax 3.6.8 cp36 1.15.4 1.2.0
build_jax 3.7.2 cp37 1.15.4 1.2.0
build_jax 3.8.0 cp38 1.17.3 1.3.2
build_jax 3.9.0 cp39 1.19.4 1.5.4
rm -fr dist
build_jax 3.6.8 1.15.4 1.2.0
build_jax 3.7.2 1.15.4 1.2.0
build_jax 3.8.0 1.17.3 1.3.2
build_jax 3.9.0 1.19.4 1.5.4
172 changes: 172 additions & 0 deletions build/build_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script that builds a jaxlib wheel, intended to be run via bazel run as part
# of the jaxlib build process.

# Most users should not run this script directly; use build.py instead.

import argparse
import functools
import glob
import os
import platform
import shutil
import subprocess
import sys
import tempfile

from bazel_tools.tools.python.runfiles import runfiles

parser = argparse.ArgumentParser()
parser.add_argument(
"--sources_path",
default=None,
help="Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used.")
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.")
args = parser.parse_args()

r = runfiles.Create()


def _is_windows():
return sys.platform.startswith("win32")


def _copy_so(src_file, dst_dir, dst_filename=None):
src_filename = os.path.basename(src_file)
if not dst_filename:
if _is_windows() and src_filename.endswith(".so"):
dst_filename = src_filename[:-3] + ".pyd"
else:
dst_filename = src_filename
dst_file = os.path.join(dst_dir, dst_filename)
shutil.copy(src_file, dst_file)


def _copy_normal(src_file, dst_dir, dst_filename=None):
src_filename = os.path.basename(src_file)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
shutil.copy(src_file, dst_file)


def copy_file(src_file, dst_dir, dst_filename=None):
if src_file.endswith(".so"):
_copy_so(src_file, dst_dir, dst_filename=dst_filename)
else:
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)

def patch_copy_xla_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
f.write(src)


def patch_copy_tpu_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
"from . import xla_client")
src = src.replace(
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
"from . import tpu_client_extension as _tpu_client")
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
f.write(src)


def prepare_wheel(sources_path):
"""Assembles a source tree for the wheel in `sources_path`."""
jaxlib_dir = os.path.join(sources_path, "jaxlib")
os.makedirs(jaxlib_dir)
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)

copy_file(r.Rlocation("__main__/jaxlib/setup.py"), dst_dir=sources_path)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/init.py"), dst_filename="__init__.py")
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/lapack.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft.py"))
if r.Rlocation("__main__/jaxlib/cusolver_kernels.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cublas_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.so"))
if r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cublas_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng.py"))

if _is_windows():
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
else:
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
patch_copy_xla_client_py(jaxlib_dir)

if not _is_windows():
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"))
patch_copy_tpu_client_py(jaxlib_dir)


def build_wheel(sources_path, output_path):
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
platform_name = {
"Linux": "manylinux2010",
"Darwin": "macosx_10_9",
"Windows": "win",
}[platform.system()]
cpu_name = "amd64" if platform.system() == "Windows" else "x86_64"
python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
f"{sys.version_info.minor}")
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
cwd = os.getcwd()
os.chdir(sources_path)
subprocess.run([sys.executable, "setup.py", "bdist_wheel",
python_tag_arg, platform_tag_arg])
os.chdir(cwd)
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
output_file = os.path.join(output_path, os.path.basename(wheel))
sys.stderr.write(f"Output wheel: {output_file}\n\n")
sys.stderr.write(f"To install the newly-built jaxlib wheel, run:\n")
sys.stderr.write(f" pip install {output_file}\n\n")
shutil.copy(wheel, output_path)


tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jaxlib")
sources_path = tmpdir.name

try:
os.makedirs(args.output_path, exist_ok=True)
prepare_wheel(sources_path)
build_wheel(sources_path, args.output_path)
finally:
if tmpdir:
tmpdir.cleanup()

10 changes: 1 addition & 9 deletions build/build_wheel_docker_entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,25 @@ fi
# Builds and activates a specific Python version.
pyenv local "$PY_VERSION"

PY_TAG=$(python -c "import packaging.tags as t; print(t.interpreter_name() + t.interpreter_version())")

echo "Python tag: $PY_TAG"

# Workaround for https://github.com/bazelbuild/bazel/issues/9254
export BAZEL_LINKLIBS="-lstdc++"

export JAX_CUDA_VERSION=$3
case $2 in
cuda-included)
python build.py --enable_cuda --bazel_startup_options="--output_user_root=/build/root"
python include_cuda.py
PLAT_NAME="manylinux2010_x86_64"
;;
cuda)
python build.py --enable_cuda --bazel_startup_options="--output_user_root=/build/root"
PLAT_NAME="manylinux2010_x86_64"
;;
nocuda)
python build.py --bazel_startup_options="--output_user_root=/build/root"
PLAT_NAME="manylinux2010_x86_64"
;;
*)
usage
esac

export JAX_CUDA_VERSION=$3
python setup.py bdist_wheel --python-tag "$PY_TAG" --plat-name "$PLAT_NAME"
if ! python -m auditwheel show dist/jaxlib-*.whl | grep 'platform tag: "manylinux2010_x86_64"' > /dev/null; then
# Print output for debugging
python -m auditwheel show dist/jaxlib-*.whl
Expand Down
Loading

0 comments on commit c06ead6

Please sign in to comment.