Skip to content

Commit

Permalink
[sparse] implement sparse rule for lax.reshape_p
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 2, 2022
1 parent 29a828a commit 1a9a796
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/experimental/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@
bcoo_multiply_dense as bcoo_multiply_dense,
bcoo_multiply_sparse as bcoo_multiply_sparse,
bcoo_reduce_sum as bcoo_reduce_sum,
bcoo_reshape as bcoo_reshape,
bcoo_sort_indices as bcoo_sort_indices,
bcoo_sort_indices_p as bcoo_sort_indices_p,
bcoo_spdot_general_p as bcoo_spdot_general_p,
Expand Down
57 changes: 57 additions & 0 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,63 @@ def bcoo_concatenate(operands, *, dimension):
return BCOO((new_data, new_indices), shape=out_aval.shape)


def bcoo_reshape(mat, *, new_sizes, dimensions):
"""Sparse implementation of {func}`jax.lax.reshape`.
Args:
operand: BCOO array to be reshaped.
new_sizes: sequence of integers specifying the resulting shape. The size
of the final array must match the size of the input. This must be specified
such that batch, sparse, and dense dimensions do not mix.
dimensions: optional sequence of integers specifying the permutation order of
the input shape. If specified, the length must match ``operand.shape``.
Additionally, dimensions must only permute among like dimensions of mat:
batch, sparse, and dense dimensions cannot be permuted.
Returns:
out: reshaped array.
"""
if mat.n_dense:
# TODO(jakevdp): implement reshape of dense dimensions.
raise NotImplementedError("bcoo_reshape for matrices with dense dimensions.")

if mat.n_batch:
batch_size = np.prod(mat.shape[:mat.n_batch])
cuml_shape = np.cumprod(new_sizes)
if batch_size not in cuml_shape:
raise ValueError("bcoo_reshape: new shape cannot mix batch and sparse dimensions; "
f"got shape={mat.shape} new_shape={new_sizes} with n_batch={mat.n_batch}")
ind = cuml_shape.searchsorted(batch_size, side='right')
else:
ind = 0
batch_sizes, sparse_sizes = new_sizes[:ind], new_sizes[ind:]
batch_perm, sparse_perm, _ = _validate_permutation(mat.data, mat.indices, dimensions or tuple(range(mat.ndim)), mat.shape)

if (mat.indices.shape[:mat.n_batch] != mat.data.shape[:mat.n_batch] != mat.shape[:mat.n_batch]):
# TODO(jakevdp) implement this case via broadcast_in_dim
raise NotImplementedError("reshape of arrays with broadacsted batch dimensions.")

# Reshape batch dimensions: this is accomplished via a standard reshape.
data = lax.reshape(
mat.data, new_sizes=(*batch_sizes, *mat.data.shape[mat.n_batch:]),
dimensions=(*batch_perm, *range(mat.n_batch, mat.data.ndim)))
indices = lax.reshape(
mat.indices, new_sizes=(*batch_sizes, *mat.indices.shape[mat.n_batch:]),
dimensions=(*batch_perm, *range(mat.n_batch, mat.indices.ndim)))

# Reshape the sparse dimensions: this is accomplished by re-indexing.
index_cols = tuple(indices[..., i] for i in sparse_perm)
sparse_shape = tuple(mat.shape[mat.n_batch + i] for i in sparse_perm)
flat_indices = jnp.ravel_multi_index(index_cols, dims=sparse_shape, mode='clip')
new_index_cols = jnp.unravel_index(flat_indices, sparse_sizes)
new_indices = jnp.concatenate([col[..., None] for col in new_index_cols], axis=-1)
with jax.numpy_rank_promotion('allow'):
oob_indices = (indices >= jnp.array(mat.shape[mat.n_batch:])).any(-1)
new_indices = new_indices.at[oob_indices].set(jnp.array(sparse_sizes))

return BCOO((data, new_indices), shape=new_sizes)


def _tuple_replace(tup, ind, val):
return tuple(val if i == ind else t for i, t in enumerate(tup))

Expand Down
12 changes: 12 additions & 0 deletions jax/experimental/sparse/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,13 @@ def _squeeze_sparse(spenv, *spvalues, dimensions):

sparse_rules[lax.squeeze_p] = _squeeze_sparse

def _reshape_sparse(spenv, *spvalues, new_sizes, dimensions):
operand, = spvalues_to_arrays(spenv, spvalues)
result = sparse.bcoo_reshape(operand, new_sizes=new_sizes, dimensions=dimensions)
return arrays_to_spvalues(spenv, (result,))

sparse_rules[lax.reshape_p] = _reshape_sparse

def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
# TODO(jakevdp): currently this approach discards all information about
# shared data & indices when generating the sparsified jaxpr. The
Expand Down Expand Up @@ -736,6 +743,10 @@ def _sum(self, *args, **kwargs):
"""Sum array along axis."""
return sparsify(lambda x: x.sum(*args, **kwargs))(self)

def _reshape(self, *args, **kwargs):
"""Sum array along axis."""
return sparsify(lambda x: x.reshape(*args, **kwargs))(self)

def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
mode=None, fill_value=None):
# mirrors lax_numpy._rewriting_take.
Expand All @@ -751,6 +762,7 @@ def _sparse_rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=Fa
_swap_args = lambda f: lambda a, b: f(b, a)

_bcoo_methods = {
'reshape': _reshape,
'sum': _sum,
"__neg__": sparsify(jnp.negative),
"__pos__": sparsify(jnp.positive),
Expand Down
8 changes: 8 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,6 +1761,14 @@ def test_bcoo_reduce_sum(self, shape, dtype, n_batch, n_dense, axes):
tol = {np.float32: 1E-6, np.float64: 1E-14}
self.assertAllClose(result_dense, result_sparse, atol=tol, rtol=tol)

def test_bcoo_reshape_error(self):
x = sparse.BCOO.fromdense(jnp.ones((2, 2, 3)), n_batch=1)
with self.assertRaisesRegex(ValueError, ".*cannot mix batch and sparse dimensions.*"):
x.reshape(3, 2, 2)
y = sparse.BCOO((x.data[:1], x.indices), shape=x.shape)
with self.assertRaisesRegex(NotImplementedError, "reshape of arrays with broadacsted batch dimensions."):
y.reshape(2, 3, 2)

@unittest.skipIf(jtu.device_under_test() == "tpu", "TPU has insufficient precision")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_{}".format(
Expand Down
51 changes: 51 additions & 0 deletions tests/sparsify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,57 @@ def testSparseConcatenate(self, shapes, func, n_batch):
sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
self.assertArraysEqual(f(arrs), f(sparrs).todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}",
"shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense}
for shape, new_shape, n_batch, n_dense in [
[(6,), (2, 3), 0, 0],
[(1, 4), (2, 2), 0, 0],
[(12, 2), (2, 3, 4), 0, 0],
[(1, 3, 2), (2, 3), 0, 0],
[(1, 6), (2, 3, 1), 0, 0],
[(2, 3, 4), (3, 8), 0, 0],
[(2, 3, 4), (1, 2, 12), 1, 0],
[(2, 3, 4), (6, 2, 2), 2, 0],
]))
def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
rng = jtu.rand_some_zero(self.rng())
arr = rng(shape, 'int32')
arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

arr2 = arr.reshape(new_shape)
arr2_sparse = arr_sparse.reshape(new_shape)

self.assertArraysEqual(arr2, arr2_sparse.todense())

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}",
"shape": shape, "new_shape": new_shape, "n_batch": n_batch, "n_dense": n_dense,
"dimensions": dimensions}
for shape, new_shape, n_batch, n_dense, dimensions in [
[(2, 3, 4), (24,), 0, 0, None],
[(2, 3, 4), (24,), 0, 0, (0, 1, 2)],
[(2, 3, 4), (24,), 0, 0, (0, 2, 1)],
[(2, 3, 4), (24,), 0, 0, (1, 0, 2)],
[(2, 3, 4), (24,), 0, 0, (1, 2, 0)],
[(2, 3, 4), (24,), 0, 0, (2, 0, 1)],
[(2, 3, 4), (24,), 0, 0, (2, 1, 0)],
[(4, 2, 3), (2, 2, 6), 1, 0, (0, 1, 2)],
[(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)],
[(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)],
[(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)],
]))
def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense, dimensions):
rng = jtu.rand_some_zero(self.rng())
arr = rng(shape, 'int32')
arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

f = self.sparsify(lambda x: lax.reshape(x, new_shape, dimensions=dimensions))

arr2 = f(arr)
arr2_sparse = f(arr_sparse)

self.assertArraysEqual(arr2, arr2_sparse.todense())

def testSparseWhileLoop(self):
def cond_fun(params):
Expand Down

0 comments on commit 1a9a796

Please sign in to comment.