Skip to content

Commit

Permalink
generate dynamic_slice rather than slice for simple indexing/slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 3, 2022
1 parent 91d134d commit 1627bc6
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3776,10 +3776,9 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
if (arr.ndim > 0 and isinstance(idx, (int, np.integer)) and
not isinstance(idx, (bool, np.bool_)) and isinstance(arr.shape[0], int)):
if 0 <= idx < arr.shape[0]:
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
else:
return lax.index_in_dim(arr, idx, keepdims=False)
# Use dynamic rather than static index here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
if (arr.ndim > 0 and isinstance(arr.shape[0], int) and
isinstance(idx, slice) and
(type(idx.start) is int or idx.start is None) and
Expand All @@ -3794,6 +3793,10 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
elif step == 1:
# Use dynamic rather than static slice here to avoid slow repeated execution:
# See https://github.com/google/jax/issues/12198
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
else:
return lax.slice_in_dim(arr, start, stop, step)

Expand Down

0 comments on commit 1627bc6

Please sign in to comment.