Skip to content

Commit

Permalink
Merge pull request jax-ml#15415 from jakevdp:sparse-add
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 522101591
  • Loading branch information
jax authors committed Apr 5, 2023
2 parents 0e549ac + 05f32a7 commit 7b4e579
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
14 changes: 13 additions & 1 deletion jax/experimental/sparse/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ def _transpose_sparse(spenv, *spvalues, permutation):

def _add_sparse(spenv, *spvalues):
X, Y = spvalues
out_shape = lax.broadcast_shapes(X.shape, Y.shape)
if X.is_sparse() and Y.is_sparse():
if X.shape != Y.shape:
raise NotImplementedError("Addition between sparse matrices of different shapes.")
Expand All @@ -621,7 +622,18 @@ def _add_sparse(spenv, *spvalues):
out_data = lax.concatenate([spenv.data(X), spenv.data(Y)], dimension=spenv.indices(X).ndim - 2)
out_spvalue = spenv.sparse(X.shape, out_data, out_indices)
else:
raise NotImplementedError("Addition between sparse and dense array.")
if Y.is_sparse():
X, Y = Y, X
assert X.is_sparse() and Y.is_dense()
if Y.shape != out_shape:
raise NotImplementedError(
"Addition between a sparse array X and a dense array Y is not implemented when "
"the output shape is larger than Y.shape. This is to prevent silent densification "
"of a large sparse array. If this is your intent, you can explicitly cast the sparse "
"array to a dense matrix.")
X_promoted, Y_promoted = spvalues_to_arrays(spenv, (X, Y))
out = X_promoted.todense() + Y_promoted
out_spvalue = spenv.dense(out)

return (out_spvalue,)

Expand Down
16 changes: 16 additions & 0 deletions tests/sparsify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,22 @@ def testSparseAdd(self):

self.assertAllClose(out.todense(), x.todense() + y.todense())

# Sparse + dense: supported
x = BCOO.fromdense(jnp.arange(6.)).reshape(2, 3)
y = jnp.ones((2, 3))

out = self.sparsify(operator.add)(x, y)
self.assertAllClose(out, x.todense() + y)

out = self.sparsify(operator.add)(y, x)
self.assertAllClose(out, x.todense() + y)

# Sparse + dense: unsupported
msg = "Addition between a sparse array X and a dense array Y is not implemented"
with self.assertRaisesRegex(NotImplementedError, msg):
self.sparsify(operator.add)(x, 1.)


@jtu.sample_product(
[dict(shape=shape, n_batch=n_batch, n_dense=n_dense)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
Expand Down

0 comments on commit 7b4e579

Please sign in to comment.