From 0f1520b6d20a075fce3568e114edcc04f8881470 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 Apr 2021 09:09:10 -0400 Subject: [PATCH] Enable variadic select_and_gather on TPU. --- jax/_src/lax/lax.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c08c164e59cd..c11e4926891d 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -5732,7 +5732,7 @@ def reducer(): return snd(out) # TODO(phawkins): use this translation rule on all platforms. -def _select_and_gather_add_translation_using_varadic_reducewindow( +def _select_and_gather_add_translation_using_variadic_reducewindow( c, tangents, operand, *, select_prim, window_dimensions, window_strides, padding, base_dilation, window_dilation): shape = c.get_shape(operand) @@ -5827,13 +5827,12 @@ def _select_and_gather_add_batching_rule( batching.primitive_batchers[select_and_gather_add_p] = \ _select_and_gather_add_batching_rule # TODO(b/183233858): use variadic reducewindow on GPU, when implemented. -# TODO(b/184942267): use variadic reducewindow on TPU, when fixed. if jax.lib._xla_extension_version >= 15: xla.backend_specific_translations['cpu'][select_and_gather_add_p] = \ - _select_and_gather_add_translation_using_varadic_reducewindow -xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial( - _select_and_gather_add_translation, - max_bits=32) + _select_and_gather_add_translation_using_variadic_reducewindow +xla.backend_specific_translations['tpu'][select_and_gather_add_p] = \ + _select_and_gather_add_translation_using_variadic_reducewindow + def _sort_abstract_eval(*args, **kwargs): args = tuple(raise_to_shaped(arg) for arg in args)