[pallas] Add support for cross-platform lowering
While implementing this I discovered that the multi-platform
lowering support does not handle the case when the lowering
rule for a platform invokes tracing (via `mlir.lower_fun`) and
that tracing encounters a primitive that has lowering rules
only for particular platforms. To support this, I added
`LoweringRuleContext.platforms`, which overrides
`ModuleContext.platforms` with a potentially narrower set of
lowering platforms. Added a test for this scenario.
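
For context, the sketch below reproduces the problematic shape in miniature. It is adapted from the test added to tests/export_test.py in this commit; the primitive names are illustrative, and the import paths (the `jax._src` internals and `jax.experimental.export`) are assumptions based on that test rather than part of this change.

import functools
import numpy as np
import jax
from jax._src import core
from jax._src.interpreters import mlir
from jax.experimental import export  # assumed import path for the export API

def times_n_lowering(n, ctx, x):
  # Lower n * x as a chain of n - 1 additions.
  res = x
  for _ in range(n - 1):
    res = mlir.hlo.AddOp(res, x)
  return res.results

# One primitive lowerable only on CPU, another only on CUDA.
times_2 = core.Primitive("times_2")
times_2.def_abstract_eval(lambda x: x)
mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2), "cpu")

times_3 = core.Primitive("times_3")
times_3.def_abstract_eval(lambda x: x)
mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), "cuda")

# A primitive whose per-platform rules themselves trace (via mlir.lower_fun)
# functions that bind the single-platform primitives above.
times_2_or_3 = core.Primitive("times_2_or_3")
times_2_or_3.def_abstract_eval(lambda x: x)
mlir.register_lowering(
    times_2_or_3, mlir.lower_fun(times_2.bind, multiple_results=False), "cpu")
mlir.register_lowering(
    times_2_or_3, mlir.lower_fun(times_3.bind, multiple_results=False), "cuda")

# Multi-platform export. Previously the lower_fun tracing inside the CPU
# branch still saw ModuleContext.platforms == ("cpu", "cuda") and failed
# because times_2 has no CUDA rule; LoweringRuleContext.platforms now narrows
# that branch to just "cpu".
exp = export.export(jax.jit(lambda x: times_2_or_3.bind(x)),
                    lowering_platforms=["cpu", "cuda"])(np.float32(42.))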
gnecula committed Jun 12, 2024
1 parent 9b68873 commit 97db0e7
Showing 6 changed files with 115 additions and 33 deletions.
23 changes: 18 additions & 5 deletions jax/_src/interpreters/mlir.py
@@ -585,6 +585,8 @@ class ModuleContext:
ip: ir.InsertionPoint
symbol_table: ir.SymbolTable
backend_or_name: str | xb.XlaBackend | None
# The lowering platforms for the module. Can be more than one only when
# exporting.
platforms: Sequence[str]
axis_context: AxisContext
keepalives: list[Any]
@@ -689,6 +691,9 @@ class LoweringRuleContext:
# module_context.shape_poly_state.dim_vars
dim_var_values: Sequence[ir.Value] = ()
compute_type: str | None = None
# Override module_context.platforms if not None. Used during multi-platform
# lowering, when in a scope with a subset of the module_context.platforms.
platforms: Sequence[str] | None = None

def set_tokens_out(self, tokens_out: TokenSet):
assert self.tokens_out is None, 'Should only set `tokens_out` once.'
@@ -1662,7 +1667,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
rule_args: the args of the lowering rules.
rule_kwargs: the kwargs of the lowering rules.
"""
platforms: Sequence[str] = ctx.module_context.platforms
platforms: Sequence[str] = ctx.platforms or ctx.module_context.platforms
# Special case the common case (single-platform lowering)
if len(platforms) == 1:
rule = platform_rules.get(platforms[0], default_rule)
@@ -1723,7 +1728,10 @@ def lower_per_platform(ctx: LoweringRuleContext,
index=rule_idx_op,
num_branches=len(kept_rules))
for i, rule in enumerate(kept_rules):
inner_ctx = ctx.replace()
platforms_for_this_rule = [p
for p, rule_idx in platform_to_kept_rules_idx.items()
if rule_idx == i]
inner_ctx = ctx.replace(platforms=platforms_for_this_rule)
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
output = rule(inner_ctx, *rule_args, **rule_kwargs)
@@ -1764,7 +1772,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
The returned function does not use `avals_out`, so callers may pass any value
as `avals_out`."""
def f_lowered(ctx, *args, **params):
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)

@@ -1774,11 +1782,12 @@ def f_lowered(ctx, *args, **params):
# case, we need to form a jaxpr with leading binders for those axis size
# arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
# and we need to call jaxpr_subcomp with these arguments made explicit.
assert ctx.axis_size_env is not None
args = (*ctx.axis_size_env.values(), *args)
idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
i32_aval = core.ShapedArray((), np.dtype('int32'))
implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))
explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) # type: ignore
if type(a) is core.DShapedArray else a, True)
for a in ctx.avals_in]
wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
@@ -1787,8 +1796,12 @@ def f_lowered(ctx, *args, **params):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?

if ctx.platforms is not None:
sub_context = ctx.module_context.replace(platforms=ctx.platforms)
else:
sub_context = ctx.module_context
out, tokens = jaxpr_subcomp(
ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
_ir_consts(consts), *map(wrap_singleton_ir_values, args),
dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens)
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -96,7 +96,7 @@ def _lower_fun(*args):
return mosaic.as_tpu_kernel(
mosaic_module,
out_avals,
backend=ctx.module_context.backend,
backend="tpu",
kernel_name=name,
cost_estimate=mosaic_params.get("cost_estimate"),
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
52 changes: 31 additions & 21 deletions jax/_src/pallas/pallas_call.py
@@ -720,38 +720,48 @@ def _pallas_call_lowering(
impl = partial(_pallas_call_impl, **params, interpret=True)
return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)

try:
[platform] = ctx.module_context.platforms
except ValueError:
raise ValueError(
"Can only lower pallas_call on a single platform."
) from None

if platform == "cpu":
def cpu_lowering(ctx: mlir.LoweringRuleContext,
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
**params):
raise ValueError("Only interpret mode is supported on CPU backend.")
elif platform == "cuda" or platform == "rocm":

def tpu_lowering(ctx: mlir.LoweringRuleContext,
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
**params):
try:
from jax._src.pallas.mosaic import pallas_call_registration
except ImportError:
raise _unsupported_lowering_error("tpu")
else:
return pallas_call_registration.pallas_call_tpu_lowering_rule(
ctx, *in_nodes, **params
)

def gpu_lowering(ctx: mlir.LoweringRuleContext,
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
**params):
try:
if _PALLAS_USE_MOSAIC_GPU.value:
from jax._src.pallas.mosaic_gpu import pallas_call_registration
else:
from jax._src.pallas.triton import pallas_call_registration # type: ignore
except ImportError:
pass
raise _unsupported_lowering_error("gpu")
else:
return pallas_call_registration.pallas_call_lowering(
ctx, *in_nodes, interpret=interpret, **params
)
elif platform == "tpu":
try:
from jax._src.pallas.mosaic import pallas_call_registration # type: ignore
except ImportError:
pass
else:
return pallas_call_registration.pallas_call_tpu_lowering_rule(
ctx, *in_nodes, interpret=interpret, **params
ctx, *in_nodes, **params
)

raise _unsupported_lowering_error(platform)
return mlir.lower_per_platform(ctx, "pallas_call",
dict(cpu=cpu_lowering,
tpu=tpu_lowering,
cuda=gpu_lowering,
rocm=gpu_lowering),
None, # default_rule
effects.no_effects,
*in_nodes,
interpret=interpret,
**params)


mlir.register_lowering(pallas_call_p, _pallas_call_lowering)
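
For reference, a hypothetical primitive could adopt the same dispatch pattern that `_pallas_call_lowering` now uses: register a single rule that fans out per platform via `mlir.lower_per_platform`, whose signature is visible in the mlir.py hunk above. Everything below (`my_prim`, the rule bodies) is an illustrative sketch, not part of this change.

from jax._src import core, effects
from jax._src.interpreters import mlir

my_prim = core.Primitive("my_prim")
my_prim.def_abstract_eval(lambda x: x)

def _cpu_rule(ctx: mlir.LoweringRuleContext, x, **params):
  raise ValueError("my_prim is not supported on CPU")

def _gpu_rule(ctx: mlir.LoweringRuleContext, x, **params):
  # Inside this branch ctx.platforms is narrowed to the platform(s) that
  # selected this rule, so any nested lower_fun tracing stays narrow too.
  return [x]  # identity lowering, as a placeholder

def _my_prim_lowering(ctx: mlir.LoweringRuleContext, x, **params):
  return mlir.lower_per_platform(
      ctx, "my_prim",
      dict(cpu=_cpu_rule, cuda=_gpu_rule, rocm=_gpu_rule),
      None,                  # no default rule: other platforms are an error
      effects.no_effects,
      x, **params)

mlir.register_lowering(my_prim, _my_prim_lowering)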
5 changes: 2 additions & 3 deletions jax/_src/pallas/triton/pallas_call_registration.py
@@ -76,9 +76,8 @@ def pallas_call_lowering(
)
triton_params = compiler_params.get("triton", compiler_params)
num_warps = triton_params.pop("num_warps", 4)
if len(ctx.module_context.platforms) > 1:
raise NotImplementedError("multi-platform lowering for Pallas kernels")
if ctx.module_context.platforms[0] == "rocm":
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
if lowering_platform == "rocm":
num_stages = triton_params.pop("num_stages", 1)
else:
num_stages = triton_params.pop("num_stages", 3)
59 changes: 59 additions & 0 deletions tests/export_test.py
@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

from collections.abc import Sequence
import contextlib
import dataclasses
import functools
@@ -1369,6 +1370,64 @@ def test_multi_platform_nested_inside_single_platform_export(self):
res2 = exp2.call(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))

def test_multi_platform_mlir_lower_fun_with_platform_specific_primitives(self):
# A primitive with multiple lowering rules, which themselves involve
# tracing primitives with per-platform rules, using mlir.lower_fun.
# This situation arises for Pallas lowering.
def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext,
x: mlir.ir.Value) -> Sequence[mlir.ir.Value]:
# Lowering n * x
res = x
for i in range(n - 1):
res = mlir.hlo.AddOp(res, x)
return res.results

times_2 = core.Primitive("__testing_times_2") # x2 for cpu
times_2.def_abstract_eval(lambda x: x)
# Define lowering rules only for the relevant platforms, ensure there
# is no error about missing lowering rules
mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2),
"cpu")

times_3 = core.Primitive("__testing_times_3") # x3 for cuda
times_3.def_abstract_eval(lambda x: x)
mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3),
"cuda")

times_4 = core.Primitive("__testing_times_4") # x4 for tpu
times_4.def_abstract_eval(lambda x: x)
mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4),
"tpu")

times_2_or_3 = core.Primitive("__testing_times_2_or_3") # x2 for cpu, x3 for cuda
times_2_or_3.def_abstract_eval(lambda x: x)
mlir.register_lowering(times_2_or_3,
mlir.lower_fun(times_2.bind,
multiple_results=False), "cpu")
mlir.register_lowering(times_2_or_3,
mlir.lower_fun(times_3.bind,
multiple_results=False), "cuda")

times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4") # x2 for cpu, x3 for cuda, x4 for tpu
times_2_or_3_or_4.def_abstract_eval(lambda x: x)
times_2_or_3_or_4_lowering_cpu_cuda = mlir.lower_fun(times_2_or_3.bind,
multiple_results=False)
for platform in ["cpu", "cuda"]:
mlir.register_lowering(times_2_or_3_or_4,
times_2_or_3_or_4_lowering_cpu_cuda,
platform)
mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind,
multiple_results=False),
"tpu")

@jax.jit
def f(x):
return times_2_or_3_or_4.bind(x)
x = np.float32(42.)
exp = export.export(f, lowering_platforms=["cpu", "cuda", "tpu"])(x)
expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()])
self.assertAllClose(exp.call(x), expected)

def test_multi_platform_and_poly(self):
if jtu.test_device_matches(["gpu"]):
# The export is not applicable to GPU
7 changes: 4 additions & 3 deletions tests/pallas/export_pallas_test.py
@@ -43,11 +43,12 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
a = np.arange(8)
exp = export.export(
add_vectors,
# TODO(necula): Make this test work on GPU also
lowering_platforms=["tpu"],
lowering_platforms=["tpu", "cuda"],
)(a, a)

if jtu.device_under_test() == "tpu":
if (jtu.device_under_test() == "tpu" or
(jtu.device_under_test() == "gpu" and
jtu.is_cuda_compute_capability_at_least("8.0"))):
res = export.call(exp)(a, a)
self.assertAllClose(res, a + a)

