Skip to content

Commit

Permalink
Rename count_pjit_cache_miss with count_pjit_cpp_cache_miss becau…
Browse files Browse the repository at this point in the history
…se it is confusing which cache the first function is taking about as pjit has many caches

PiperOrigin-RevId: 521559652
  • Loading branch information
yashk2810 authored and jax authors committed Apr 3, 2023
1 parent 6f2256a commit 78678ee
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def count_primitive_compiles():


@contextmanager
def count_pjit_cache_miss():
def count_pjit_cpp_cache_miss():
original_pjit_lower = pjit_lib._pjit_lower
count = [0]

Expand Down
20 changes: 10 additions & 10 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2099,7 +2099,7 @@ def test_single_device_pjit_cpp_dispatch(self):
inp_data = np.arange(math.prod(shape)).reshape(shape)

f = pjit(lambda x: x @ x.T, in_shardings=None, out_shardings=None)
with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
arr1 = jax.device_put(
inp_data, jax.sharding.NamedSharding(mesh, P('x')))
Expand All @@ -2114,7 +2114,7 @@ def test_single_device_add_single_compile(self):
b = jax.device_put(jnp.array([4, 5, 6], dtype=jnp.float32),
jax.devices()[0])

with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(2):
f1(a, b)
self.assertEqual(count[0], 1)
Expand Down Expand Up @@ -2246,7 +2246,7 @@ def f(y, **kwargs):
self.assertEqual(kwargs, {'x': 'foo'})
return y * y

with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
y = jnp.arange(8.)
f_names = pjit(f, static_argnames='x')
f_names(y, x='foo')
Expand Down Expand Up @@ -2278,7 +2278,7 @@ def test_pjit_different_default_device(self):
with jax.default_device(test_device):
f(1)

with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
f(1)

with jax.default_device(system_default_device):
Expand All @@ -2291,7 +2291,7 @@ def test_pjit_different_default_device(self):
with jax.default_device(system_default_device):
f(1)

# The count here is 0 because before `count_pjit_cache_miss`, `f` was
# The count here is 0 because before `count_pjit_cpp_cache_miss`, `f` was
# called with `system_default_device` and `test_device` so it was added
# to the cache. Subsequent calls hit the C++ cache.
self.assertEqual(count[0], 0)
Expand Down Expand Up @@ -2705,7 +2705,7 @@ def f(x):

inp = jnp.arange(3.)

with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(f)(inp)
self.assertEqual(count[0], 1)
Expand All @@ -2714,24 +2714,24 @@ def test_pjit_no_global_cache_hit_axis_resources(self):
mesh = jtu.create_global_mesh((1,), ('x',))
s = NamedSharding(mesh, P('x'))

with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)(jnp.arange(8.0))
self.assertEqual(count[0], 10)

with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(lambda x: x * 2, device=jax.devices()[0])(jnp.arange(8.))
self.assertEqual(count[0], 10)

pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)
with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pf(jnp.arange(8.))
self.assertEqual(count[0], 1)

pf1 = pjit(lambda x: x * 2, device=jax.devices()[0])
with jtu.count_pjit_cache_miss() as count:
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pf1(jnp.arange(8.))
self.assertEqual(count[0], 1)
Expand Down

0 comments on commit 78678ee

Please sign in to comment.