Skip to content

Commit

Permalink
Enable extra args with input output aliasing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619041158
  • Loading branch information
sharadmv authored and jax authors committed Mar 26, 2024
1 parent 69980a2 commit f93c320
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,11 @@ def pallas_call_tpu_lowering_rule(
dimension_semantics=dimension_semantics, mesh=mesh)
if debug:
print(mosaic_module)
if extra_args and input_output_aliases:
raise NotImplementedError(
"Cannot use both input_output_aliases and extra_args."
)
num_extra_args = len(extra_args)
num_dyn_bounds = grid_mapping.num_dynamic_grid_bounds
input_output_aliases = tuple(
(a[0] + num_dyn_bounds, a[1]) for a in input_output_aliases
(a[0] + num_dyn_bounds + num_extra_args, a[1])
for a in input_output_aliases
)
out_avals = [jax_core.ShapedArray(s.shape, s.dtype) for s in out_shapes]
def _lower_fun(*args):
Expand Down

0 comments on commit f93c320

Please sign in to comment.