Skip to content

Commit

Permalink
Update xla_client._version and add missing version checks to JAX
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 449021408
  • Loading branch information
skye authored and jax authors committed May 16, 2022
1 parent bd20f0f commit 744f6b4
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions jax/experimental/compilation_cache/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def _hash_computation(hash_obj, xla_computation):
hash_obj.update(scrubbed_hlo)

def _hash_compile_options(hash_obj, compile_options_obj):
assert len(dir(compile_options_obj)) == 32, (
if xla_client._version >= 68: # Remove when minimum jaxlib version >= 0.3.11
expected_num_compile_options = 32
else:
expected_num_compile_options = 31
assert len(dir(compile_options_obj)) == expected_num_compile_options, (
f"Unexpected number of CompileOption fields: "
f"{len(dir(compile_options_obj))}. This likely: means that an extra "
f"field was added, and this function needs to be updated.")
Expand All @@ -126,7 +130,8 @@ def _hash_compile_options(hash_obj, compile_options_obj):
_hash_bool(hash_obj, compile_options_obj.tuple_arguments)
_hash_int(hash_obj, compile_options_obj.num_replicas)
_hash_int(hash_obj, compile_options_obj.num_partitions)
_hash_int(hash_obj, compile_options_obj.profile_version)
if xla_client._version >= 68: # Remove when minimum jaxlib version >= 0.3.11
_hash_int(hash_obj, compile_options_obj.profile_version)
if compile_options_obj.device_assignment is not None:
hash_obj.update(compile_options_obj.device_assignment.serialize())

Expand Down

0 comments on commit 744f6b4

Please sign in to comment.