Skip to content

Commit

Permalink
cleanup old jaxlib version check
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 7, 2023
1 parent fa099fd commit c3d3c19
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,7 @@ def _xla_gc_callback(*args):
xla_extension_version: int = getattr(xla_client, '_version', 0)

import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
if jaxlib.version.__version_info__ >= (0, 4, 11):
# TODO(sharadmv): make this unconditional when minimum jaxlib version is
# bumped to 0.4.11
try:
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
except ModuleNotFoundError:
pass
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error

# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version
Expand Down

0 comments on commit c3d3c19

Please sign in to comment.