Skip to content

Commit

Permalink
Don't use the trace context in the prune_closed_jaxpr_outputs cache.
Browse files Browse the repository at this point in the history
This code only manipulates jaxprs, and does not trace anything.

PiperOrigin-RevId: 648398046
  • Loading branch information
hawkinsp authored and jax authors committed Jul 1, 2024
1 parent 9653f58 commit 55589fb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,7 @@ def prune_closed_jaxpr_outputs(
) -> ClosedJaxpr:
return _prune_closed_jaxpr_outputs(jaxpr, tuple(used_outputs))

@weakref_lru_cache
@partial(weakref_lru_cache, trace_context_in_key=False)
def _prune_closed_jaxpr_outputs(
jaxpr: ClosedJaxpr, used_outputs: tuple[bool, ...]
) -> ClosedJaxpr:
Expand Down
7 changes: 5 additions & 2 deletions jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ def clear_all_caches():

memoize = cache(max_size=None)

def weakref_lru_cache(call: Callable, maxsize=2048):
def weakref_lru_cache(call: Callable, maxsize=2048,
trace_context_in_key: bool = True):
"""
Least recently used cache decorator with weakref support.
Expand All @@ -326,7 +327,9 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
behave similar to `functools.lru_cache`.
"""
global _weakref_lru_caches
cached_call = xc.weakref_lru_cache(config.trace_context, call, maxsize)
cached_call = xc.weakref_lru_cache(
config.trace_context if trace_context_in_key else _ignore,
call, maxsize)
_weakref_lru_caches.add(cached_call)
return cached_call

Expand Down

0 comments on commit 55589fb

Please sign in to comment.