Skip to content

Commit

Permalink
Enable variadic select_and_gather on TPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Apr 13, 2021
1 parent 85637c7 commit 0f1520b
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5732,7 +5732,7 @@ def reducer():
return snd(out)

# TODO(phawkins): use this translation rule on all platforms.
def _select_and_gather_add_translation_using_varadic_reducewindow(
def _select_and_gather_add_translation_using_variadic_reducewindow(
c, tangents, operand, *, select_prim, window_dimensions, window_strides,
padding, base_dilation, window_dilation):
shape = c.get_shape(operand)
Expand Down Expand Up @@ -5827,13 +5827,12 @@ def _select_and_gather_add_batching_rule(
batching.primitive_batchers[select_and_gather_add_p] = \
_select_and_gather_add_batching_rule
# TODO(b/183233858): use variadic reducewindow on GPU, when implemented.
# TODO(b/184942267): use variadic reducewindow on TPU, when fixed.
if jax.lib._xla_extension_version >= 15:
xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \
_select_and_gather_add_translation_using_varadic_reducewindow
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
_select_and_gather_add_translation,
max_bits=32)
_select_and_gather_add_translation_using_variadic_reducewindow
xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \
_select_and_gather_add_translation_using_variadic_reducewindow


def _sort_abstract_eval(*args, **kwargs):
args = tuple(raise_to_shaped(arg) for arg in args)
Expand Down

0 comments on commit 0f1520b

Please sign in to comment.