From a144eb234b7bd0286be2914c8c8d36ece3e564b5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 5 Sep 2024 14:15:33 -0700 Subject: [PATCH] Add compute_on_context_manager to thread-local jit state. This avoids false cache hits. PiperOrigin-RevId: 671507042 --- jax/BUILD | 2 +- jax/_src/compute_on.py | 5 +++++ jax/_src/config.py | 7 ++++++- tests/memories_test.py | 16 ++++++++++++++++ 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index c4a421362f37..4c622194941f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -765,7 +765,7 @@ pytype_strict_library( pytype_strict_library( name = "compute_on", srcs = ["_src/compute_on.py"], - deps = [], + deps = [":config"], ) pytype_strict_library( diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 25b2be78d287..4495d38f9da8 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -15,6 +15,7 @@ from __future__ import annotations import threading from contextlib import contextmanager +from jax._src import config class ComputeOnContext(threading.local): @@ -28,6 +29,8 @@ def __init__(self): @contextmanager def extend_compute_type(c_type: str): compute_on_context.stack.append(c_type) + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) try: if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: raise NotImplementedError( @@ -36,6 +39,8 @@ def extend_compute_type(c_type: str): yield compute_on_context.stack[-1] finally: compute_on_context.stack.pop() + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None diff --git a/jax/_src/config.py b/jax/_src/config.py index a0b91e2ad6ed..646e487c5f1c 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -202,12 +202,16 @@ def trace_context(): tls = jax_jit.thread_local_state() 
axis_env_state = () mesh_context_manager = () + compute_on_context_manager = () context: Any = tls.extra_jit_context if context and context.axis_env_state is not None: axis_env_state = context.axis_env_state if context and context.mesh_context_manager: mesh_context_manager = context.mesh_context_manager - return (axis_env_state, mesh_context_manager, enable_x64.value, + if context and context.compute_on_context_manager: + compute_on_context_manager = context.compute_on_context_manager + return (axis_env_state, mesh_context_manager, compute_on_context_manager, + enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, numpy_dtype_promotion.value, default_device.value, random_seed_offset.value, @@ -853,6 +857,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): dynamic_trace_state: Any | None = None axis_env_state: Hashable = () mesh_context_manager: Hashable = () + compute_on_context_manager: Hashable = () # Values set by _StateContextManager context managers. # CAUTION: these must be initialized to `None`! The state context manager diff --git a/tests/memories_test.py b/tests/memories_test.py index affe5de99644..68aecfdf669f 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1420,6 +1420,22 @@ def test_fn(x_in, y_in): self.assertArraysEqual(x_out, x1 * x1) self.assertArraysEqual(y_out, y1 + y1) + def test_compute_on_cache_miss(self): + @jax.jit + def f(x): + return x * 2 + + inp = jnp.arange(10) + with jtu.count_jit_tracing_cache_miss() as count: + with compute_on('device_host'): + f(inp) + + with compute_on('device'): + f(inp) + + # 2 for `f` and `2` for `mul` (compute type changes for `mul`) + self.assertEqual(count[0], 4) + class ActivationOffloadingTest(jtu.JaxTestCase):