Skip to content

Commit

Permalink
add non-advanced boolean indexing support
Browse files Browse the repository at this point in the history
also don't sub-sample indexing tests (run them all)
fixes jax-ml#166
  • Loading branch information
mattjj committed Dec 23, 2018
1 parent fd3645a commit 6d6b526
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 29 deletions.
40 changes: 31 additions & 9 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ def _rewriting_take(arr, idx, axis=0):

# Handle slice index (only static, otherwise an error is raised)
elif isinstance(idx, slice):
if not _all(elt is None or isinstance(core.get_aval(elt), ConcreteArray)
if not _all(elt is None or type(core.get_aval(elt)) is ConcreteArray
for elt in (idx.start, idx.stop, idx.step)):
msg = ("Array slice indices must have static start/stop/step to be used "
"with Numpy indexing syntax. Try lax.dynamic_slice instead.")
Expand All @@ -1448,6 +1448,27 @@ def _rewriting_take(arr, idx, axis=0):
result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
return lax.rev(result, [axis]) if needs_rev else result

# Handle non-advanced bool index (only static, otherwise an error is raised)
elif (isinstance(abstract_idx, ShapedArray) and onp.issubdtype(abstract_idx.dtype, onp.bool_)
or isinstance(idx, list) and _all(not _shape(e) and onp.issubdtype(_dtype(e), onp.bool_)
for e in idx)):
if isinstance(idx, list):
idx = array(idx)
abstract_idx = core.get_aval(idx)

if not type(abstract_idx) is ConcreteArray:
msg = ("Array boolean indices must be static (e.g. no dependence on an "
"argument to a jit or vmap function).")
raise IndexError(msg)
else:
if idx.ndim > arr.ndim or idx.shape != arr.shape[:idx.ndim]:
msg = "Boolean index shape did not match indexed array shape prefix."
raise IndexError(msg)
else:
reshaped_arr = arr.reshape((-1,) + arr.shape[idx.ndim:])
int_idx, = onp.where(idx.ravel())
return lax.index_take(reshaped_arr, (int_idx,), (0,))

# Handle non-advanced tuple indices by recursing once
elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
canonical_idx = _canonicalize_tuple_index(arr, idx)
Expand Down Expand Up @@ -1487,10 +1508,11 @@ def _rewriting_take(arr, idx, axis=0):
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
elif _is_advanced_int_indexer(idx):
canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
idx_noadvanced = [slice(None) if _is_int(e) else e for e in canonical_idx]
idx_noadvanced = [slice(None) if _is_int_arraylike(e) else e
for e in canonical_idx]
arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int(e))
advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int_arraylike(e))
idx_advanced, axes = zip(*advanced_pairs)
idx_advanced = broadcast_arrays(*idx_advanced)

Expand Down Expand Up @@ -1522,11 +1544,11 @@ def _is_advanced_int_indexer(idx):
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
if isinstance(idx, (tuple, list)):
# We assume this check comes *after* the check for non-advanced tuple index,
# and hence we already know at least one element is a sequence
return _all(e is None or e is Ellipsis or isinstance(e, slice) or _is_int(e)
for e in idx)
# and hence we already know at least one element is a sequence if it's a tuple
return _all(e is None or e is Ellipsis or isinstance(e, slice)
or _is_int_arraylike(e) for e in idx)
else:
return _is_int(idx)
return _is_int_arraylike(idx)


def _is_advanced_int_indexer_without_slices(idx):
Expand All @@ -1539,11 +1561,11 @@ def _is_advanced_int_indexer_without_slices(idx):
return True


def _is_int(x):
def _is_int_arraylike(x):
"""Returns True if x is array-like with integer dtype, False otherwise."""
return (isinstance(x, int) and not isinstance(x, bool)
or onp.issubdtype(getattr(x, "dtype", None), onp.integer)
or isinstance(x, (list, tuple)) and _all(_is_int(e) for e in x))
or isinstance(x, (list, tuple)) and _all(_is_int_arraylike(e) for e in x))


def _canonicalize_tuple_index(arr, idx):
Expand Down
82 changes: 62 additions & 20 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""

@parameterized.named_parameters(jtu.cases_from_list({
"testcase_name":
"{}_inshape={}_indexer={}".format(
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
@parameterized.named_parameters({
"testcase_name": "{}_inshape={}_indexer={}".format(
name, jtu.format_shape_dtype_string( shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer
} for name, index_specs in [
("OneIntIndex", [
IndexSpec(shape=(3,), indexer=1),
Expand Down Expand Up @@ -154,14 +153,14 @@ class IndexingTest(jtu.JaxTestCase):
IndexSpec(shape=(3, 4), indexer=()),
]),
] for shape, indexer in index_specs for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
@jtu.skip_on_devices("tpu")
def testStaticIndexing(self, shape, dtype, rng, indexer):
args_maker = lambda: [rng(shape, dtype)]
fun = lambda x: x[indexer]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list({
@parameterized.named_parameters({
"testcase_name":
"{}_inshape={}_indexer={}".format(name,
jtu.format_shape_dtype_string(
Expand Down Expand Up @@ -233,7 +232,7 @@ def testStaticIndexing(self, shape, dtype, rng, indexer):
# IndexSpec(shape=(3, 4), indexer=()),
# ]),
] for shape, indexer in index_specs for dtype in float_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
@jtu.skip_on_devices("tpu")
def testStaticIndexingGrads(self, shape, dtype, rng, indexer):
tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
Expand All @@ -257,7 +256,7 @@ def _ReplaceSlicesWithTuples(self, idx):
else:
return idx, lambda x: x

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand All @@ -280,7 +279,7 @@ def _ReplaceSlicesWithTuples(self, idx):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng, indexer):
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

Expand All @@ -292,7 +291,7 @@ def fun(x, unpacked_indexer):
args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
self.assertRaises(IndexError, lambda: fun(*args_maker()))

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand All @@ -312,7 +311,7 @@ def fun(x, unpacked_indexer):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testDynamicIndexingWithIntegers(self, shape, dtype, rng, indexer):
unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

Expand All @@ -324,7 +323,7 @@ def fun(x, unpacked_indexer):
self._CompileAndCheck(fun, args_maker, check_dtypes=True)

@skip
@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand All @@ -346,7 +345,7 @@ def fun(x, unpacked_indexer):
]
for shape, indexer in index_specs
for dtype in float_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def DISABLED_testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng, indexer):
# TODO(mattjj): re-enable (test works but for grad-of-compile, in flux)
tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
Expand All @@ -360,7 +359,7 @@ def fun(unpacked_indexer, x):
arr = rng(shape, dtype)
check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand Down Expand Up @@ -412,13 +411,13 @@ def fun(unpacked_indexer, x):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
args_maker = lambda: [rng(shape, dtype), indexer]
fun = lambda x, idx: x[idx]
self._CompileAndCheck(fun, args_maker, check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand Down Expand Up @@ -470,14 +469,14 @@ def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
]
for shape, indexer in index_specs
for dtype in float_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer):
tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
arg = rng(shape, dtype)
fun = lambda x: x[indexer]**2
check_grads(fun, (arg,), 2, tol, tol, tol)

@parameterized.named_parameters(jtu.cases_from_list(
@parameterized.named_parameters(
{"testcase_name": "{}_inshape={}_indexer={}"
.format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
"shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
Expand Down Expand Up @@ -533,7 +532,7 @@ def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer):
]
for shape, indexer in index_specs
for dtype in all_dtypes
for rng in [jtu.rand_default()]))
for rng in [jtu.rand_default()])
def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
indexer_with_dummies = [e if isinstance(e, onp.ndarray) else ()
for e in indexer]
Expand Down Expand Up @@ -588,6 +587,49 @@ def foo(x):

self.assertAllClose(a1, a2, check_dtypes=True)

def testBooleanIndexingArray1D(self):
idx = onp.array([True, True, False])
x = api.device_put(onp.arange(3))
ans = x[idx]
expected = onp.arange(3)[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingList1D(self):
idx = [True, True, False]
x = api.device_put(onp.arange(3))
ans = x[idx]
expected = onp.arange(3)[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingArray2DBroadcast(self):
idx = onp.array([True, True, False, True])
x = onp.arange(8).reshape(4, 2)
ans = api.device_put(x)[idx]
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingList2DBroadcast(self):
idx = [True, True, False, True]
x = onp.arange(8).reshape(4, 2)
ans = api.device_put(x)[idx]
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingArray2D(self):
idx = onp.array([[True, False],
[False, True],
[False, False],
[True, True]])
x = onp.arange(8).reshape(4, 2)
ans = api.device_put(x)[idx]
expected = x[idx]
self.assertAllClose(ans, expected, check_dtypes=False)

def testBooleanIndexingDynamicShapeError(self):
x = onp.zeros(3)
i = onp.array([True, True, False])
self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i))


if __name__ == "__main__":
absltest.main()

0 comments on commit 6d6b526

Please sign in to comment.