Skip to content

Commit

Permalink
Merge pull request jax-ml#7514 from skye:compilation_cache_op_name
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 389002209
  • Loading branch information
jax authors committed Aug 5, 2021
2 parents 1646dda + dcf3712 commit 03ec444
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion jax/experimental/compilation_cache/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import hashlib
import re

import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
from jax.lib import xla_client
Expand Down Expand Up @@ -68,7 +70,17 @@ def get_cache_key(xla_computation, compile_options) -> str:
"""
hash_obj = hashlib.sha256()
hash_obj.update(xla_computation.as_serialized_hlo_module_proto())
# The HLO op_name metadata sometimes includes Python function pointers,
# which cause spurious cache misses. Scrub anything that looks like a
# function pointer. Example op_name metadata:
# op_name="jit(s)/custom_jvp_call_jaxpr
# [ jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f3fa30f0940>\n
# num_consts=0 ]"
# TODO(skye): in theory this could cause us to scrub meaningful binary proto
# data. Do something more robust.
serialized_hlo = xla_computation.as_serialized_hlo_module_proto()
scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
hash_obj.update(scrubbed_hlo)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing computation: {hash_obj.digest().hex()}")
_hash_compile_options(hash_obj, compile_options)
Expand Down

0 comments on commit 03ec444

Please sign in to comment.