Skip to content

Commit

Permalink
Change Linux docker wheel build to build separate wheels for each min…
Browse files Browse the repository at this point in the history
…or Python version.

* Use pyenv to build specific Python releases against which we build wheels. Otherwise we are limited to only those Python releases present in the OS repository, which only correspond to certain major versions of Python.
* Build with docker instead of nvidia-docker. We need a CUDA image to build JAX, but we don't require a GPU or nvidia-docker unless we want to actually run things.
* Various other cleanups.
  • Loading branch information
hawkinsp authored and mattjj committed Jan 18, 2019
1 parent b2fe46f commit aaff39a
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 23 deletions.
22 changes: 18 additions & 4 deletions build/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,30 @@ LABEL maintainer "Matt Johnson <[email protected]>"

RUN apt-get update && apt-get install -y --no-install-recommends \
dh-autoreconf git curl \
python python-pip python-dev \
python3 python3-pip python3-dev
RUN pip install numpy scipy cython setuptools wheel && pip3 install numpy scipy cython setuptools wheel
build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev \
libsqlite3-dev wget llvm libncurses5-dev xz-utils tk-dev \
libxml2-dev libxmlsec1-dev libffi-dev

RUN git clone https://github.com/nixos/patchelf /tmp/patchelf
WORKDIR /tmp/patchelf
RUN bash bootstrap.sh && ./configure && make && make install && rm -r /tmp/patchelf


WORKDIR /
RUN git clone https://github.com/pyenv/pyenv.git /pyenv
ENV PYENV_ROOT /pyenv
RUN /pyenv/bin/pyenv install 2.7.15
RUN /pyenv/bin/pyenv install 3.6.8
RUN /pyenv/bin/pyenv install 3.7.2

# We pin numpy to a version < 1.16 to avoid version compatibility issues.
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 2.7.15 && pip install numpy==1.15.4 scipy cython setuptools wheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy cython setuptools wheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy cython setuptools wheel


WORKDIR /
RUN curl -O https://raw.githubusercontent.com/google/jax/762abcf29b4a155c3de325c27ecffa5d4a3da28c/build/build_wheel_docker_entrypoint.sh
COPY build_wheel_docker_entrypoint.sh /build_wheel_docker_entrypoint.sh
RUN chmod +x /build_wheel_docker_entrypoint.sh

WORKDIR /build
Expand Down
18 changes: 10 additions & 8 deletions build/build_jaxlib_wheels.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#!/bin/bash
set -xev
JAXLIB_VERSION=$(sed -n "s/^ \+version=[']\(.*\)['],$/\\1/p" jax/build/setup.py)

PYTHON_VERSIONS="py2 py3"
PYTHON_VERSIONS="2.7.15 3.6.8 3.7.2"
CUDA_VERSIONS="9.0 9.2 10.0"
CUDA_VARIANTS="cuda" # "cuda cuda-included"
CUDA_VARIANTS="cuda" # "cuda-included"

mkdir -p dist

Expand All @@ -13,8 +12,11 @@ docker build -t jaxbuild jax/build/
for PYTHON_VERSION in $PYTHON_VERSIONS
do
mkdir -p dist/nocuda/
nvidia-docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION nocuda
mv dist/*.whl dist/nocuda/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-manylinux1_x86_64.whl
docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION nocuda
for I in $(find dist/nocuda/ -name "*.whl")
do
mv $I $(echo $I | sed -e 's/linux_x86_64/manylinux1_x86_64/')
done
done

# build the cuda linux packages, tagging with linux_x86_64
Expand All @@ -26,8 +28,8 @@ do
for CUDA_VARIANT in $CUDA_VARIANTS
do
mkdir -p dist/${CUDA_VARIANT}${CUDA_VERSION//.}
nvidia-docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION $CUDA_VARIANT
mv dist/*.whl dist/${CUDA_VARIANT}${CUDA_VERSION//.}/jaxlib-${JAXLIB_VERSION}-${PYTHON_VERSION}-none-linux_x86_64.whl
docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION $CUDA_VARIANT
mv dist/*.whl dist/${CUDA_VARIANT}${CUDA_VERSION//.}/
done
done
done
done
2 changes: 1 addition & 1 deletion build/build_jaxlib_wheels_macos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ build_jax () {
rm -fr build/dist
build_jax 2.7.15 cp27
build_jax 3.6.8 cp36
build_jax 3.7.2 cp37
build_jax 3.7.2 cp37
23 changes: 13 additions & 10 deletions build/build_wheel_docker_entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ then
exit 1
fi

export PYENV_ROOT="/pyenv"
export PATH="$PYENV_ROOT/bin:$PATH"
eval "$(pyenv init -)"

PY_VERSION="$1"
echo "Python version $PY_VERSION"

git clone https://github.com/google/jax /build/jax
cd /build/jax/build

Expand All @@ -19,15 +26,11 @@ then
usage
fi

case $1 in
py3)
update-alternatives --install /usr/bin/python python /usr/bin/python3 10
;;
py2)
;;
*)
usage
esac
# Builds and activates a specific Python version.
pyenv local "$PY_VERSION"

PY_TAG=$(python -c "import wheel; import wheel.pep425tags as t; print(t.get_abbr_impl() + t.get_impl_ver())")
echo "Python tag $PY_TAG"

case $2 in
cuda-included)
Expand All @@ -44,5 +47,5 @@ case $2 in
usage
esac

python setup.py bdist_wheel
python setup.py bdist_wheel --python-tag "$PY_TAG" --plat-name "linux_x86_64"
cp -r dist/* /dist

0 comments on commit aaff39a

Please sign in to comment.