[Pallas/GPU] Add support for jnp.cumsum
PiperOrigin-RevId: 595578503
sharadmv authored and jax authors committed Jan 4, 2024
1 parent 0d3fd53 commit 9112afb
Showing 2 changed files with 96 additions and 0 deletions.
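For context, a minimal sketch of what this change enables: calling jnp.cumsum inside a Pallas kernel that is lowered to Triton. This example is not part of the commit; the kernel name, shapes, and dtype are illustrative assumptions.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def cumsum_kernel(x_ref, y_ref):
  # jnp.cumsum inside the kernel body now lowers to a Triton scan op.
  y_ref[...] = jnp.cumsum(x_ref[...], axis=0)

x = jnp.ones((32, 8), jnp.float32)
y = pl.pallas_call(
    cumsum_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)
# y should match jnp.cumsum(x, axis=0).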
65 changes: 65 additions & 0 deletions jax/_src/pallas/triton/lowering.py
@@ -390,6 +390,71 @@ def _atomic_lowering_rule(
triton_lowering_rules[primitives.atomic_rmw_p] = _atomic_lowering_rule


def _associative_scan_lowering(
    body, ctx: TritonLoweringRuleContext, args, axes
):
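  """Lowers an associative scan (with combine function `body`) to a Triton scan op."""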
  flat_args = tree_util.tree_leaves(args)
  (axis,) = axes
  dtype = ctx.avals_in[0].dtype
  in_avals = [
      jax_core.ShapedArray((), dtype=dtype),
      jax_core.ShapedArray((), dtype=dtype),
  ]
  in_tree = tree_util.tree_structure((args, args))
  flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
      lu.wrap_init(body), in_tree
  )
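  # Trace the scalar combine function (e.g. jnp.add for cumsum) to a jaxpr
  # so it can be lowered into the scan op's region below.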
  combine_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      flat_fun, in_avals
  )
  out_tree = out_tree_thunk()
  del out_tree  # Not needed
  if consts:
    raise NotImplementedError("Associative scan with constants not supported.")
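  # Build the scan op and populate its region: the combine block takes two
  # scalars per operand and returns their combination via scan_ret.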
  element_types = [arg.type.scalar for arg in flat_args]
  builder = ctx.builder
  scan_op = builder.create_scan([t.handle for t in flat_args], axis)
  region = scan_op.get_region(0)
  old_ip = builder.get_insertion_point()
  param_types = element_types * 2
  ir_param_types = [ty.to_ir(builder) for ty in param_types]
  block = builder.create_block_with_parent(region, ir_param_types)
  combine_args = [
      tl.core.tensor(block.arg(i), ty) for i, ty in enumerate(param_types)
  ]
  results = lower_jaxpr_to_triton_ir(
      ctx.context, combine_jaxpr, None, *combine_args
  )
  handles = [r.handle for r in results]
  builder.create_scan_ret(*handles)
  builder.restore_insertion_point(old_ip)
  scan_op.verify()

  def wrap_tensor(x, scalar_ty):
    if ctx.avals_out[0].shape:
      res_ty = tl.block_type(scalar_ty, ctx.avals_out[0].shape)
    else:
      # 0d-tensor -> scalar
      res_ty = scalar_ty
    return tl.core.tensor(x, res_ty)

  results = [
      wrap_tensor(scan_op.get_result(i), ty)
      for i, ty in enumerate(element_types)
  ]
  return results


def _cumsum_lowering_rule(
    ctx: TritonLoweringRuleContext,
    x,
    *, axis: int, reverse: bool
):
  if reverse:
    raise NotImplementedError("Reverse cumsum is not supported.")
  return _associative_scan_lowering(jnp.add, ctx, x, (axis,))[0]

triton_lowering_rules[lax.cumsum_p] = _cumsum_lowering_rule


_TRITON_FN_MAPPING = {
    # Unary ops.
    lax.neg_p: tl.semantic.minus,
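Note that _associative_scan_lowering is generic over the combine function; cumsum is simply the jnp.add instantiation. As an illustration only (not part of this diff), a hypothetical cumulative-max rule could reuse it in the same pattern, assuming a matching primitive were registered:

def _cummax_lowering_rule(ctx: TritonLoweringRuleContext, x, *, axis: int):
  # Hypothetical: reuse the generic scan lowering with jnp.maximum.
  return _associative_scan_lowering(jnp.maximum, ctx, x, (axis,))[0]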
31 changes: 31 additions & 0 deletions tests/pallas/pallas_test.py
@@ -663,6 +663,37 @@ def reduce(x_ref, y_ref):
      y_ref = op(x, axis=axis)
      np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)

  @parameterized.named_parameters(*[
      (f"{dtype}_{axis}", dtype, axis)
      for axis in [0, 1]
      for dtype in ["float16", "float32", "int32", "uint32"]
  ])
  def test_cumsum(self, dtype, axis):
    m, n = 32, 8
    out_dtype = dtype
    def make_x(key):
      if jnp.issubdtype(dtype, jnp.integer):
        return random.permutation(
            key, jnp.arange(m * n, dtype=dtype), independent=True
        ).reshape(m, n)
      else:
        return random.normal(key, (m, n), dtype=dtype)
    out_shape = jax.ShapeDtypeStruct((m, n), out_dtype)
    grid = ()
    @functools.partial(
        self.pallas_call,
        out_shape=out_shape,
        grid=grid)
    def reduce(x_ref, y_ref):
      x = x_ref[...]
      y_ref[...] = jnp.cumsum(x, axis=axis)
    for i, key in enumerate(random.split(random.key(0), 20)):
      x = make_x(key)
      y = reduce(x)
      y_ref = jnp.cumsum(x, axis=axis)
      np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)

  def test_using_pallas_slice(self):
    m, n = 32, 4
    out_shape = jax.ShapeDtypeStruct((4, n), jnp.float32)
