Skip to content

Commit

Permalink
[Pallas] Add stride in Pallas dynamic slice and support strided load/…
Browse files Browse the repository at this point in the history
…store.

PiperOrigin-RevId: 615940113
  • Loading branch information
bythew3i authored and jax authors committed Mar 14, 2024
1 parent 1cef1d9 commit 2048e3c
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 39 deletions.
100 changes: 72 additions & 28 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from jax._src.util import safe_zip
from jax._src.util import split_list
from jax._src.util import unzip2
from jax._src.util import unzip3
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -746,47 +745,71 @@ def _maybe_cast_to_index(cast_to_index, x):
return _make_index(x)
return _ensure_mlir_value(x, aval=jax_core.ShapedArray((), jnp.int32))

def _index_to_start_size_stride(
    idx: indexing.Slice | int | ir.Value, cast_to_index: bool
) -> tuple[ir.Value, int, int, bool]:
  """Converts one index element into a `(start, size, stride, squeeze)` tuple.

  Args:
    idx: A single index element: an `indexing.Slice`, a Python `int`, or any
      other value whose `np.shape` is `()`.
    cast_to_index: Forwarded to `_maybe_cast_to_index` when converting the
      start offset.

  Returns:
    `(start, size, stride, squeeze)`. For an `indexing.Slice` the size and
    stride are taken from the slice and `squeeze` is False. For every other
    accepted index the size and stride are both 1 and `squeeze` is True.

  Raises:
    ValueError: If `idx` is neither a `Slice` nor an `int` and has a
      non-empty shape.
  """
  assert not isinstance(idx, slice)
  if isinstance(idx, indexing.Slice):
    start = _maybe_cast_to_index(cast_to_index, idx.start)
    return start, idx.size, idx.stride, False
  # Python ints and ()-shaped values are handled identically; only the latter
  # need the shape check.
  if not isinstance(idx, int) and np.shape(idx):
    raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
  return _maybe_cast_to_index(cast_to_index, idx), 1, 1, True


def _indexer_to_start_size_stride(
    indexer: NDIndexer,
    ref_block_shape: tuple[int | pl_core.Mapped, ...],
    *,
    cast_to_index: bool,
) -> tuple[
    tuple[ir.Value, ...],
    tuple[int, ...],
    tuple[int, ...],
    tuple[bool, ...],
    tuple[int | pl_core.Mapped, ...],
]:
  """Converts an indexer into per-dimension starts, sizes and strides.

  Walks `ref_block_shape` one dimension at a time. A dimension equal to
  `pl_core.mapped` consumes no index and contributes start 0, size 1,
  stride 1 and is squeezed. Every other dimension consumes the next entry of
  `indexer.indices` via `_index_to_start_size_stride`.

  Args:
    indexer: The indexer whose `indices` are consumed in order.
    ref_block_shape: The block shape being indexed.
    cast_to_index: Forwarded to the per-index conversion helpers.

  Returns:
    `(starts, sizes, strides, squeeze_dims, new_ref_block_shape)`, where the
    first four tuples have one entry per dimension of `ref_block_shape` and
    `new_ref_block_shape` holds the sizes of the dimensions that are not
    squeezed.
  """
  indices_iter = iter(indexer.indices)
  starts, sizes, strides, squeeze_dims = [], [], [], []
  for s in ref_block_shape:
    if s is pl_core.mapped:
      start, size, stride, squeeze_dim = (
          _maybe_cast_to_index(cast_to_index, 0),
          1,
          1,
          True,
      )
    else:
      start, size, stride, squeeze_dim = _index_to_start_size_stride(
          next(indices_iter), cast_to_index
      )
    starts.append(start)
    sizes.append(size)
    strides.append(stride)
    squeeze_dims.append(squeeze_dim)
  # Every index must have been matched with a non-mapped dimension.
  next_index = next(indices_iter, None)
  assert next_index is None, (indexer.indices, ref_block_shape)
  new_ref_block_shape = tuple(
      size for size, squeeze in zip(sizes, squeeze_dims) if not squeeze
  )
  return (
      tuple(starts),
      tuple(sizes),
      tuple(strides),
      tuple(squeeze_dims),
      new_ref_block_shape,
  )


def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
Expand All @@ -796,9 +819,15 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
tuple[int | pl_core.Mapped, ...]]:
assert ref_block_shape is not None
target_shape = indexer.get_indexer_shape()
starts, sizes, squeeze_dims, ref_block_shape = _indexer_to_start_size(
indexer, ref_block_shape, cast_to_index=False,
starts, sizes, strides, squeeze_dims, ref_block_shape = (
_indexer_to_start_size_stride(
indexer,
ref_block_shape,
cast_to_index=False,
)
)
if not all((s is None or s == 1) for s in strides):
raise NotImplementedError("Strided slices of references are unsupported.")
target_ref_ty = ir.MemRefType.get(
tuple(sizes), _dtype_to_ir_type(ref_aval.dtype),
memory_space=ref.type.memory_space)
Expand Down Expand Up @@ -846,14 +875,21 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
for a in idx_aval.indices
):
raise ValueError("Cannot do int indexing on TPU")
starts, sizes, _, _ = _indexer_to_start_size(
idx, ref_block_shape, cast_to_index=True,
starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
idx,
ref_block_shape,
cast_to_index=True,
)
need_stride = not all((s is None or s == 1) for s in strides)
load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype)
if is_smem_load:
if ctx.avals_out[0].shape:
raise ValueError("Can only load scalars from SMEM")
return memref.LoadOp(ref, starts).result
if need_stride:
load_val = tpu.StridedLoadOp(
aval_to_ir_type(load_aval), ref, starts, strides
).result
else:
load_val = vector.LoadOp(aval_to_ir_type(load_aval), ref, starts).result
if load_aval == aval_out:
Expand Down Expand Up @@ -896,10 +932,12 @@ def _masked_swap_lowering_rule(
raise NotImplementedError(
"Indexing into a ()-shaped Ref not yet supported on TPU.")

starts, _, _, _ = _indexer_to_start_size(
idx, ref_block_shape, cast_to_index=True,
starts, _, strides, _, _ = _indexer_to_start_size_stride(
idx,
ref_block_shape,
cast_to_index=True,
)

need_stride = not all((s is None or s == 1) for s in strides)
if is_smem_store:
if val_aval.shape:
raise ValueError("Can only store scalars to SMEM")
Expand All @@ -918,7 +956,10 @@ def _masked_swap_lowering_rule(
mem_aval = aval_out.update(shape=tuple(mem_slice_shape))
mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
_dtype_to_ir_type(mem_aval.dtype))
result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
if need_stride:
result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result
else:
result = vector.LoadOp(mem_aval_vec_type, ref, starts).result
if mem_aval != aval_out:
# We are slicing a scalar so provided dummy 1 indices
result_vec_type = ir.VectorType.get(aval_out.shape,
Expand All @@ -927,7 +968,10 @@ def _masked_swap_lowering_rule(
val_vec_type = ir.VectorType.get(mem_aval.shape,
_dtype_to_ir_type(mem_aval.dtype))
val = vector.ShapeCastOp(val_vec_type, val).result
vector.StoreOp(val, ref, starts)
if need_stride:
tpu.StridedStoreOp(val, ref, starts, strides)
else:
vector.StoreOp(val, ref, starts)
return result


Expand Down
35 changes: 24 additions & 11 deletions jax/_src/state/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,19 @@ class Slice:
"""Represents a slice with a dynamic start index and a fixed size."""
start: Any
size: int
stride: int = 1

def __post_init__(self):
# Reject a negative `size`; zero is allowed by this check.
if self.size < 0:
raise ValueError("`size` must not be negative.")
# Reject zero and negative strides.
if self.stride < 1:
raise ValueError("`stride` must be >= 1.")

def tree_flatten(self):
  """Flattens into `(children, aux_data)`, keeping `stride` in the aux data.

  When `start` is a Python `int` there are no children and the aux data is
  `(start, size, stride)`. Otherwise `start` is the only child and the aux
  data is `(size, stride)`.
  """
  # If `start` is statically known, we treat it as static information
  if isinstance(self.start, int):
    return (), (self.start, self.size, self.stride)
  return (self.start,), (self.size, self.stride)

@classmethod
def tree_unflatten(cls, aux_data, children) -> Slice:
Expand All @@ -51,21 +54,30 @@ def tree_unflatten(cls, aux_data, children) -> Slice:
@classmethod
def from_slice(cls, slc: slice, size: int) -> Slice:
  """Builds an instance from a Python `slice` over a dimension of `size`.

  The slice is normalised with `slice.indices(size)`. The resulting element
  count is the number of positions visited from `start` up to, but excluding,
  `stop` in increments of `step`, clamped at zero for empty slices.

  Raises:
    ValueError: If the normalised step is smaller than 1.
  """
  start, stop, step = slc.indices(size)
  if step < 1:
    raise ValueError(f"slice must have a step >= 1 (found: {step})")
  # Ceiling division: a trailing partial stride still yields one element.
  num_elements = (stop - start + step - 1) // step
  return cls(start, max(num_elements, 0), step)


def dslice(
    start: int | Array | None,
    size: int | None = None,
    stride: int | None = None,
) -> slice | Slice:
  """Constructs a `Slice` from a start, a size and an optional stride.

  Args:
    start: The start of the slice. `None` selects the full `slice(None)`.
      When `size` is omitted, `start` is instead used as the size of a slice
      starting at 0 and must be a Python `int`.
    size: The number of elements in the slice.
    stride: The step between elements; `None` means 1. Must be a Python
      `int`.

  Returns:
    `slice(None)` when `start` is `None`, otherwise a `Slice`.

  Raises:
    ValueError: If `stride` is not a Python `int`, or if `size` is omitted
      and `start` is not a Python `int`.
  """
  if start is None:
    return slice(None)
  if stride is None:
    stride = 1
  if not isinstance(stride, int):
    raise ValueError("Non-static stride in `dslice`")
  if size is None:
    if not isinstance(start, int):
      raise ValueError("Non-static `dslice`")
    # Single-argument form: the argument is the size, starting at 0.
    return Slice(0, start, stride)
  return Slice(start, size, stride)


ds = dslice # Handy alias


Expand Down Expand Up @@ -113,9 +125,10 @@ def __post_init__(self):
if value := _maybe_concretize(start):
if value >= s:
raise ValueError(f"Out of bound slice: start={value}, dim={s}.")
if value + idx.size > s:
if value + (idx.size - 1) * idx.stride >= s:
raise ValueError(
f"Out of bound slice: start={value}, size={idx.size}, dim={s}."
f"Out of bound slice: start={value}, size={idx.size},"
f" stride={idx.stride}, dim={s}."
)
continue
# The shape of indexer integers should be broadcastable up to the
Expand Down

0 comments on commit 2048e3c

Please sign in to comment.