From 1627bc6e8366d336e95c283c91d1e92e5e21f608 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 3 Nov 2022 11:40:51 -0700 Subject: [PATCH] generate dynamic_slice rather than slice for simple indexing/slicing --- jax/_src/numpy/lax_numpy.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6e12e84fe815..0087d52913ea 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3776,10 +3776,9 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)): if 0 <= idx < arr.shape[0]: - if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]): - return lax.dynamic_index_in_dim(arr, idx, keepdims=False) - else: - return lax.index_in_dim(arr, idx, keepdims=False) + # Use dynamic rather than static index here to avoid slow repeated execution: + # See https://github.com/google/jax/issues/12198 + return lax.dynamic_index_in_dim(arr, idx, keepdims=False) if (arr.ndim > 0 and isinstance(arr.shape[0], int) and isinstance(idx, slice) and (type(idx.start) is int or idx.start is None) and @@ -3794,6 +3793,10 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False, if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]): if step == 1: # TODO(mattjj, sharadmv): handle step != 1 return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0) + elif step == 1: + # Use dynamic rather than static slice here to avoid slow repeated execution: + # See https://github.com/google/jax/issues/12198 + return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0) else: return lax.slice_in_dim(arr, start, stop, step)