Skip to content

Commit

Permalink
Add a no-op batching rule for optimization_barrier_p
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704507586
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 10, 2024
1 parent 1743f2c commit 944d822
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
4 changes: 4 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6496,3 +6496,7 @@ def _optimization_barrier_lowering_rule(ctx, *args):
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)

def _optimization_barrier_batcher(batched_args, batch_dims, **params):
return optimization_barrier_p.bind(*batched_args, **params), batch_dims
batching.primitive_batchers[optimization_barrier_p] = _optimization_barrier_batcher
2 changes: 1 addition & 1 deletion tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3551,7 +3551,7 @@ def testAsarray(self, typ):
with jax.transfer_guard('disallow'):
jax.jit(asarray_closure)()

def testOptimizationBarrier(self):
def test_optimization_barrier(self):
  """optimization_barrier on a tuple of scalars returns it unchanged."""
  result = lax.optimization_barrier((2, 3))
  self.assertEqual((2, 3), result)

Expand Down
19 changes: 19 additions & 0 deletions tests/lax_vmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,25 @@ def testTopK(self, shape, dtype, k, bdims):
op2 = lambda x: lax.top_k(x, k=k)[1]
self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

@jtu.sample_product(
    [dict(shape=shape, bdims=bdims)
     for shape in [(8,), (3, 4, 5)]
     for bdims in lax_test_util.all_bdims(shape)],
    dtype=lax_test_util.default_dtypes,
)
def test_optimization_barrier_vmap(self, shape, dtype, bdims):
  """Batching optimization_barrier agrees with manual per-slice application."""
  rand = jtu.rand_small(self.rng())
  self._CheckBatching(
      lax.optimization_barrier, 5, bdims, (shape,), (dtype,), rand)

def test_optimization_barrier_vmap_out_axes(self):
  """vmap over a pytree with per-leaf in/out axes is a pass-through."""
  x = jnp.arange(8)
  y = x.reshape(1, 8)
  # Map leaf 0 over axis 0 and leaf 1 over axis 1; mirror that on the output.
  barrier = jax.vmap(lax.optimization_barrier,
                     in_axes=((0, 1),), out_axes=(0, 1))
  out = barrier((x, y))
  self.assertArraysEqual(out[0], x)
  self.assertArraysEqual(out[1], y)

@jtu.sample_product(
[dict(shape=shape, bdims=bdims, dimension=dimension, arity=arity)
for shape in [(2, 3)]
Expand Down

0 comments on commit 944d822

Please sign in to comment.