Skip to content

Commit

Permalink
Increase minimum jaxlib version to 0.4.20.
Browse files Browse the repository at this point in the history
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
  • Loading branch information
hawkinsp authored and jax authors committed Feb 21, 2024
1 parent 5da43a4 commit aad02db
Show file tree
Hide file tree
Showing 7 changed files with 3 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Remember to align the itemized text with the first line of an item within a list
* Importing the `jax.config` submodule via `import jax.config` is deprecated.
To configure JAX use `import jax` and then reference the config object
via `jax.config`.
* The minimum jaxlib version is now 0.4.20.

## jaxlib 0.4.25

Expand Down
12 changes: 0 additions & 12 deletions jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import warnings
from jax._src import hardware_utils

running_in_cloud_tpu_vm: bool = False
Expand Down Expand Up @@ -69,14 +68,3 @@ def cloud_tpu_init() -> None:
os.environ['TPU_ML_PLATFORM'] = 'JAX'
if hardware_utils.tpu_enhanced_barrier_supported():
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"

# TODO(skyewm): remove this warning at some point, say around Sept 2023.
use_pjrt_c_api = os.environ.get('JAX_USE_PJRT_C_API_ON_TPU', None)
if use_pjrt_c_api:
warnings.warn(
"JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (the new TPU "
"runtime is always enabled now). Unset the environment variable "
"to disable this warning.")

# Remove when minimum jaxlib version is >= 0.4.15
os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = "true"
2 changes: 1 addition & 1 deletion jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ def _update_disable_jit_thread_local(val):
# TODO(parkers): Remove if there are no complaints.
remat_opt_barrier = define_bool_state(
name='jax_remat_opt_barrier',
default=(lib.version >= (0, 3, 6)),
default=True,
help=('Enables using optimization-barrier op for lowering remat.'))

# TODO(sharadmv,mattjj): set default to True, then remove
Expand Down
2 changes: 1 addition & 1 deletion jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def make_release_tree(self, base_dir, files):


__version__ = _get_version_string()
_minimum_jaxlib_version = "0.4.19"
_minimum_jaxlib_version = "0.4.20"

def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
Expand Down
2 changes: 0 additions & 2 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
pmap_sharding_devices_indices_map)
Expand Down Expand Up @@ -790,7 +789,6 @@ def test_fully_replicated_donated_array_is_deleted(self):
self.assertTrue(arr.is_deleted())

@parameterized.product(dtype=jtu.dtypes.all + jtu.dtypes.custom_floats)
@unittest.skipIf(xla_extension_version < 208, "Test requires jaxlib > 0.4.19")
def test_shards_have_correct_dtype(self, dtype):
x = jnp.ones((), dtype=dtype)
for shard in x.addressable_shards:
Expand Down
4 changes: 0 additions & 4 deletions tests/cache_key_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version


config.parse_flags_with_absl()
Expand Down Expand Up @@ -69,9 +68,6 @@ def test_serialized_compile_options(self):

@jtu.skip_on_devices("cpu")
def test_hash_accelerator_devices(self):
if xla_extension_version < 209 and xla_bridge.using_pjrt_c_api():
raise unittest.SkipTest("PjRt C API not yet supported.")

devices = np.array([[jax.local_devices()[0]]])

dev_hash1 = self.get_hashed_value(cache_key._hash_devices, devices)
Expand Down
2 changes: 0 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, implements
from jax._src.util import safe_zip, NumpyComplexWarning

Expand Down Expand Up @@ -3787,7 +3786,6 @@ def testAstype(self, from_dtype, to_dtype, use_method):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)

@unittest.skipIf(xla_extension_version < 210, 'jaxlib version too old')
def testAstypeInt4(self):
# Test converting from int4 to int8
x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)
Expand Down

0 comments on commit aad02db

Please sign in to comment.