Skip to content

Commit

Permalink
Add compute_on_context_manager to thread local jit state. This is to …
Browse files Browse the repository at this point in the history
…avoid getting false cache hits

PiperOrigin-RevId: 671507042
  • Loading branch information
yashk2810 authored and jax authors committed Sep 5, 2024
1 parent 4c8bed9 commit a144eb2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 2 deletions.
2 changes: 1 addition & 1 deletion jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ pytype_strict_library(
pytype_strict_library(
name = "compute_on",
srcs = ["_src/compute_on.py"],
deps = [],
deps = [":config"],
)

pytype_strict_library(
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/compute_on.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations
import threading
from contextlib import contextmanager
from jax._src import config


class ComputeOnContext(threading.local):
Expand All @@ -28,6 +29,8 @@ def __init__(self):
@contextmanager
def extend_compute_type(c_type: str):
compute_on_context.stack.append(c_type)
config.update_thread_local_jit_state(
compute_on_context_manager=tuple(compute_on_context.stack))
try:
if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
raise NotImplementedError(
Expand All @@ -36,6 +39,8 @@ def extend_compute_type(c_type: str):
yield compute_on_context.stack[-1]
finally:
compute_on_context.stack.pop()
config.update_thread_local_jit_state(
compute_on_context_manager=tuple(compute_on_context.stack))

def current_compute_type() -> str | None:
return compute_on_context.stack[-1] if compute_on_context.stack else None
Expand Down
7 changes: 6 additions & 1 deletion jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,16 @@ def trace_context():
tls = jax_jit.thread_local_state()
axis_env_state = ()
mesh_context_manager = ()
compute_on_context_manager = ()
context: Any = tls.extra_jit_context
if context and context.axis_env_state is not None:
axis_env_state = context.axis_env_state
if context and context.mesh_context_manager:
mesh_context_manager = context.mesh_context_manager
return (axis_env_state, mesh_context_manager, enable_x64.value,
if context and context.compute_on_context_manager:
compute_on_context_manager = context.compute_on_context_manager
return (axis_env_state, mesh_context_manager, compute_on_context_manager,
enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value, numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
Expand Down Expand Up @@ -853,6 +857,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
dynamic_trace_state: Any | None = None
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
compute_on_context_manager: Hashable = ()

# Values set by _StateContextManager context managers.
# CAUTION: these must be initialized to `None`! The state context manager
Expand Down
16 changes: 16 additions & 0 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,22 @@ def test_fn(x_in, y_in):
self.assertArraysEqual(x_out, x1 * x1)
self.assertArraysEqual(y_out, y1 + y1)

def test_compute_on_cache_miss(self):
@jax.jit
def f(x):
return x * 2

inp = jnp.arange(10)
with jtu.count_jit_tracing_cache_miss() as count:
with compute_on('device_host'):
f(inp)

with compute_on('device'):
f(inp)

# 2 for `f` and `2` for `mul` (compute type changes for `mul`)
self.assertEqual(count[0], 4)


class ActivationOffloadingTest(jtu.JaxTestCase):

Expand Down

0 comments on commit a144eb2

Please sign in to comment.