From 55589fbf41031ef59a8f2f81a559e47636890c4d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 1 Jul 2024 09:40:43 -0700 Subject: [PATCH] Don't use the trace context in the prune_closed_jaxpr_outputs cache. This code only manipulates jaxprs, and does not trace anything. PiperOrigin-RevId: 648398046 --- jax/_src/interpreters/partial_eval.py | 2 +- jax/_src/util.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 497c9ea129a8..47fb411fa18d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1521,7 +1521,7 @@ def prune_closed_jaxpr_outputs( ) -> ClosedJaxpr: return _prune_closed_jaxpr_outputs(jaxpr, tuple(used_outputs)) -@weakref_lru_cache +@partial(weakref_lru_cache, trace_context_in_key=False) def _prune_closed_jaxpr_outputs( jaxpr: ClosedJaxpr, used_outputs: tuple[bool, ...] ) -> ClosedJaxpr: diff --git a/jax/_src/util.py b/jax/_src/util.py index 93562b75e0d3..5fb4ea4d7bf9 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -317,7 +317,8 @@ def clear_all_caches(): memoize = cache(max_size=None) -def weakref_lru_cache(call: Callable, maxsize=2048): +def weakref_lru_cache(call: Callable, maxsize=2048, + trace_context_in_key: bool = True): """ Least recently used cache decorator with weakref support. @@ -326,7 +327,9 @@ def weakref_lru_cache(call: Callable, maxsize=2048): behave similar to `functools.lru_cache`. """ global _weakref_lru_caches - cached_call = xc.weakref_lru_cache(config.trace_context, call, maxsize) + cached_call = xc.weakref_lru_cache( + config.trace_context if trace_context_in_key else _ignore, + call, maxsize) _weakref_lru_caches.add(cached_call) return cached_call