Skip to content

Commit

Permalink
Add option to key on extra command-line flags in persistent compilati…
Browse files Browse the repository at this point in the history
…on cache.

PiperOrigin-RevId: 409437212
  • Loading branch information
skye authored and jax authors committed Nov 12, 2021
1 parent 09ccd90 commit 649eab3
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion jax/experimental/compilation_cache/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import re
import sys
from typing import List

import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
Expand Down Expand Up @@ -173,14 +174,19 @@ def _hash_platform(hash_obj, backend):
"--xla_dump_hlo_pipeline_re",
]

extra_flag_prefixes_to_include_in_cache_key: List[str] = []

def _hash_xla_flags(hash_obj):
xla_flags = []

xla_flags_env_var = os.getenv("XLA_FLAGS")
if xla_flags_env_var:
xla_flags.extend(xla_flags_env_var.split())

xla_flags.extend(arg for arg in sys.argv if arg.startswith("--xla_"))
for arg in sys.argv:
if arg.startswith("--xla") or any(
arg.startswith(p) for p in extra_flag_prefixes_to_include_in_cache_key):
xla_flags.append(arg)

# N.B. all XLA flags that take an argument must use '=' and not a space
# (e.g. --xla_force_host_platform_device_count=8) (I think).
Expand Down

0 comments on commit 649eab3

Please sign in to comment.