Skip to content

Commit

Permalink
Specify minimum jaxlib version in a single location
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 22, 2021
1 parent a1cf066 commit f9a4162
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 17 deletions.
4 changes: 2 additions & 2 deletions build/test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
flake8
# For now, we pin the numpy version here
numpy>=1.16
# Must be kept in sync with the minimum jaxlib version in jax/lib/__init__.py
jaxlib==0.1.62
mypy==0.790
pillow
pytest-benchmark
pytest-xdist
wheel
# Install jax from the current directory; minimum required jaxlib from pypi.
.[minimum-jaxlib]
17 changes: 7 additions & 10 deletions jax/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,30 @@
'https://github.com/google/jax#installation for installation instructions.'
) from err

# Must be kept in sync with the jaxlib versions in
# - setup.py
# - build/test-requirements.txt
_minimum_jaxlib_version = (0, 1, 62)
from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
try:
from jaxlib import version as jaxlib_version
except Exception as err:
# jaxlib is too old to have version number.
msg = 'This version of jax requires jaxlib version >= {}.'
raise ImportError(msg.format('.'.join(map(str, _minimum_jaxlib_version)))
) from err
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
raise ImportError(msg) from err

version = tuple(int(x) for x in jaxlib_version.__version__.split('.'))
_minimum_jaxlib_version = tuple(int(x) for x in _minimum_jaxlib_version_str.split('.'))

# Check the jaxlib version before importing anything else from jaxlib.
def _check_jaxlib_version():
if version < _minimum_jaxlib_version:
msg = 'jaxlib is version {}, but this version of jax requires version {}.'
msg = (f'jaxlib is version {jaxlib_version.__version__}, '
f'but this version of jax requires version {_minimum_jaxlib_version_str}.')

if version == (0, 1, 23):
msg += ('\n\nA common cause of this error is that you installed jaxlib '
'using pip, but your version of pip is too old to support '
'manylinux2010 wheels. Try running:\n\n'
'pip install --upgrade pip\n'
'pip install --upgrade jax jaxlib\n')
raise ValueError(msg.format('.'.join(map(str, version)),
'.'.join(map(str, _minimum_jaxlib_version))))
raise ValueError(msg)

_check_jaxlib_version()

Expand Down
1 change: 1 addition & 0 deletions jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.

__version__ = "0.2.10"
_minimum_jaxlib_version = "0.1.62"
7 changes: 5 additions & 2 deletions jaxlib/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# This should be increased after releasing the current version (i.e. this
# is always the next version to be released).
# After a new jaxlib release, please remember to update the values of
# `_current_jaxlib_version` and `_available_cuda_versions` in setup.py to
# reflect the most recent available binaries.
# __version__ should be increased after releasing the current version
# (i.e. on master, this is always the next version to be released).
__version__ = "0.1.65"
11 changes: 8 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from setuptools import setup, find_packages

__version__ = None
_minimum_jaxlib_version = '0.1.62'
# The following should be updated with each new jaxlib release.
_current_jaxlib_version = '0.1.64'
_available_cuda_versions = ['101', '102', '110', '111', '112']

_dct = {}
with open('jax/version.py') as f:
exec(f.read(), globals())
exec(f.read(), _dct)
__version__ = _dct['__version__']
_minimum_jaxlib_version = _dct['_minimum_jaxlib_version']

setup(
name='jax',
Expand All @@ -37,6 +39,9 @@
'opt_einsum',
],
extras_require={
# Minimum jaxlib version; used in testing.
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],

# CPU-only jaxlib can be installed via:
# $ pip install jax[cpu]
'cpu': [f'jaxlib>={_minimum_jaxlib_version}'],
Expand Down

0 comments on commit f9a4162

Please sign in to comment.