Skip to content

Commit

Permalink
defer to custom eltype for slice lowering rule
Browse files Browse the repository at this point in the history
We already handled dynamic slice, but plain slice is eltype-polymorphic too.
  • Loading branch information
froystig committed Aug 10, 2022
1 parent 8a1b478 commit 7955799
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
4 changes: 4 additions & 0 deletions jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,10 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,

def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
strides = strides or [1] * len(start_indices)
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.slice_mlir(ctx, x, start_indices, limit_indices,
strides)
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
Expand Down
22 changes: 22 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3007,6 +3007,16 @@ def handler(_, buf):
def empty_mlir(ctx):
return mlir.ir_constants(np.zeros((2,), dtype=np.dtype('uint32')))

@staticmethod
def slice_mlir(ctx, x, start_indices, limit_indices, strides):
start_indices = (*start_indices, 0)
limit_indices = (*limit_indices, 2)
strides = (*strides, 1)
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).results

@staticmethod
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
Expand Down Expand Up @@ -3251,6 +3261,18 @@ def test_vmap(self):
expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
self.assertAllClose(ys, expected)

def test_slice(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (2, 4))

def test_dynamic_slice(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1)
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (2, 4))

def test_transpose(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x: x.T)(ks)
Expand Down

0 comments on commit 7955799

Please sign in to comment.