diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 4ac6ff83f4aa..d269953b2456 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -81,13 +81,11 @@ def pallas_call_tpu_lowering_rule( dimension_semantics=dimension_semantics, mesh=mesh) if debug: print(mosaic_module) - if extra_args and input_output_aliases: - raise NotImplementedError( - "Cannot use both input_output_aliases and extra_args." - ) + num_extra_args = len(extra_args) num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds input_output_aliases = tuple( - (a[0] + num_dyn_bounds, a[1]) for a in input_output_aliases + (a[0] + num_dyn_bounds + num_extra_args, a[1]) + for a in input_output_aliases ) out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes] def _lower_fun(*args):