Skip to content

Commit

Permalink
Include XLA_FLAGS in persistent compilation cache key.
Browse files Browse the repository at this point in the history
This is to prevent false cache hits when the compiler behavior is
changed via flags. Flags known to not affect the compiled executable
(e.g. dumping HLO) are excluded from the key.

Note that any XLA flag that takes an argument should be written with `=`
rather than a space, e.g. `--xla_flag=value`, not `--xla_flag value`. I believe
this is already a requirement of ABSL flags in general, but I'm not 100% sure.

Also note that this doesn't currently support XLA flags specified via
--flagfile. Please file a feature request if this is needed.
  • Loading branch information
skye committed Oct 6, 2021
1 parent 50141e6 commit f939048
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 8 deletions.
62 changes: 54 additions & 8 deletions jax/experimental/compilation_cache/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

import hashlib
import os
import re
import sys

import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
Expand Down Expand Up @@ -57,6 +59,11 @@ def put_executable(xla_computation, compile_options, executable: xla_client.Exec
serialized_executable = backend.serialize_executable(executable)
_cache.put(cache_key, serialized_executable)

def _log_cache_key_hash(hash_obj, last_serialized: str):
  """Logs the current hex digest of ``hash_obj`` at verbosity level 1.

  Args:
    hash_obj: the hash object being built up by ``get_cache_key``; only its
      ``digest()`` is read here.
    last_serialized: human-readable name of the component that was most
      recently fed into ``hash_obj``; interpolated into the log message.
  """
  # Bail out early so the digest is only computed when it will be logged.
  if not logging.vlog_is_on(1):
    return
  digest_hex = hash_obj.digest().hex()
  logging.vlog(1, "get_cache_key hash after serializing %s: %s",
               last_serialized, digest_hex)

def get_cache_key(xla_computation, compile_options, backend) -> str:
"""Creates a hashed string to use as a key to the compilation cache.
Expand All @@ -79,17 +86,20 @@ def get_cache_key(xla_computation, compile_options, backend) -> str:
serialized_hlo = xla_computation.as_serialized_hlo_module_proto()
scrubbed_hlo = re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", serialized_hlo)
hash_obj.update(scrubbed_hlo)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing computation: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "computation")

_hash_compile_options(hash_obj, compile_options)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing compile_options: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "compile_options")

hash_obj.update(bytes(jax._src.lib.version))
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "jax_lib version")

_hash_platform(hash_obj, backend)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing the backend: {hash_obj.digest().hex()}")
_log_cache_key_hash(hash_obj, "the backend")

_hash_xla_flags(hash_obj)
_log_cache_key_hash(hash_obj, "XLA flags")

return hash_obj.digest().hex()

def _hash_compile_options(hash_obj, compile_options_obj):
Expand Down Expand Up @@ -145,6 +155,42 @@ def _hash_platform(hash_obj, backend):
_hash_string(hash_obj, backend.platform_version)
_hash_string(hash_obj, backend.runtime_type)

# XLA flags that are deliberately left out of the compilation cache key.
# _hash_xla_flags compares the part of each flag before '=' against this list
# and skips any match, so entries must be bare flag names without a value.
# These are flags known not to affect the compiled executable (e.g. the ones
# that only control HLO dumping).
_xla_flags_to_exclude_from_cache_key = [
  "--xla_dump_compress_protos",
  "--xla_dump_module_metadata",
  "--xla_dump_max_hlo_modules",
  "--xla_dump_include_timestamp",
  "--xla_dump_hlo_pass_re",
  "--xla_dump_hlo_module_re",
  "--xla_dump_hlo_snapshots",
  "--xla_dump_fusion_visualization",
  "--xla_dump_hlo_as_url",
  "--xla_dump_hlo_as_proto",
  "--xla_dump_hlo_as_text",
  "--xla_dump_to",
  "--xla_force_host_platform_device_count",
  "--xla_dump_disable_metadata",
  "--xla_dump_hlo_pipeline_re",
]

def _hash_xla_flags(hash_obj):
  """Feeds the relevant XLA flags into ``hash_obj``.

  Flags are collected from the whitespace-separated ``XLA_FLAGS`` environment
  variable first, followed by every ``sys.argv`` entry starting with
  ``--xla_``. Each flag whose name (the text before ``=``) is not listed in
  ``_xla_flags_to_exclude_from_cache_key`` is passed, value included, to
  ``_hash_string``; flags are processed in the order they were collected.

  Args:
    hash_obj: the hash object to update.
  """
  env_value = os.getenv("XLA_FLAGS")
  env_flags = env_value.split() if env_value else []
  argv_flags = [arg for arg in sys.argv if arg.startswith("--xla_")]

  # N.B. all XLA flags that take an argument must use '=' and not a space
  # (e.g. --xla_force_host_platform_device_count=8) (I think).
  for candidate in env_flags + argv_flags:
    flag_name = candidate.split('=')[0]
    if flag_name not in _xla_flags_to_exclude_from_cache_key:
      logging.vlog(1, "Including XLA flag in cache key: %s", candidate)
      _hash_string(hash_obj, candidate)
    else:
      logging.vlog(1, "Not including XLA flag in cache key: %s", candidate)

def _hash_int(hash_obj, int_var):
  """Updates ``hash_obj`` with ``int_var`` encoded as 8 big-endian bytes.

  Args:
    hash_obj: the hash object to update.
    int_var: the integer to hash; it is encoded unsigned, so ``to_bytes``
      raises ``OverflowError`` for values that are negative or do not fit
      in 8 bytes.
  """
  encoded = int_var.to_bytes(8, byteorder='big')
  hash_obj.update(encoded)

Expand Down
41 changes: 41 additions & 0 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import hashlib
import os
import random
import sys
import tempfile
import unittest
from unittest import SkipTest
Expand Down Expand Up @@ -141,6 +142,46 @@ def test_different_computations(self):
self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend),
cc.get_cache_key(computation2, compile_options, backend))

def test_xla_flags(self):
  """Checks how XLA flags influence the compilation cache key.

  Covers flags supplied through the ``XLA_FLAGS`` environment variable and
  through ``sys.argv``: a changed flag value must change the key, the same
  value must reproduce it, and a flag listed in
  ``_xla_flags_to_exclude_from_cache_key`` must leave it unchanged.
  """
  computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
  compile_options = jax._src.lib.xla_bridge.get_compile_options(
      num_replicas=1, num_partitions=1)
  backend = jax._src.lib.xla_bridge.get_backend()

  orig_xla_flags = os.getenv("XLA_FLAGS")
  orig_argv = sys.argv
  try:
    # Work on a copy of argv: appending to the original list in place would
    # make the restore in `finally` a no-op and leak the flags added below
    # into every test that runs afterwards.
    sys.argv = list(orig_argv)

    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
    key1 = cc.get_cache_key(computation, compile_options, backend)
    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=1"
    key2 = cc.get_cache_key(computation, compile_options, backend)
    self.assertNotEqual(key1, key2)

    os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
    key3 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key3)

    # Test flag in _xla_flags_to_exclude_from_cache_key
    os.environ["XLA_FLAGS"] = (
        "--xla_gpu_autotune_level=0 --xla_force_host_platform_device_count=8")
    key4 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key4)

    # Test flags given on command line
    del os.environ["XLA_FLAGS"]
    sys.argv.append("--xla_gpu_autotune_level=0")
    key5 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key5)
    # The key has to be recomputed after adding the excluded flag; comparing
    # the previously computed key5 again would not exercise this case.
    sys.argv.append("--xla_force_host_platform_device_count=8")
    key6 = cc.get_cache_key(computation, compile_options, backend)
    self.assertEqual(key1, key6)

  finally:
    if orig_xla_flags is not None:
      os.environ["XLA_FLAGS"] = orig_xla_flags
    else:
      os.environ.pop("XLA_FLAGS", None)
    sys.argv = orig_argv

def test_get_no_executable(self):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
Expand Down

0 comments on commit f939048

Please sign in to comment.