Skip to content

Commit

Permalink
[sparse] bcoo_broadcast_in_dim: default to adding leading batch dimen…
Browse files Browse the repository at this point in the history
…sions
  • Loading branch information
jakevdp committed Apr 27, 2022
1 parent 7c6a550 commit 0175479
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 3 additions & 1 deletion jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,9 @@ def _bcoo_broadcast_in_dim(data, indices, *, spinfo, shape, broadcast_dimensions
if max(sparse_dims, default=0) > min(dense_dims, default=len(shape)):
raise ValueError("Cannot mix sparse and dense dimensions during broadcast_in_dim")

new_n_batch = props.n_batch and 1 + max(broadcast_dimensions[:props.n_batch])
# All new dimensions preceding a sparse or dense dimension are batch dimensions:
new_n_batch = min(broadcast_dimensions[props.n_batch:], default=len(shape))
# TODO(jakevdp): Should new trailing dimensions be dense by default?
new_n_dense = props.n_dense and len(shape) - min(broadcast_dimensions[-props.n_dense:])
new_n_sparse = len(shape) - new_n_batch - new_n_dense

Expand Down
4 changes: 4 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,11 +1866,15 @@ def test_bcoo_broadcast_in_dim(self, shape, dtype, n_batch, n_dense):
x = jnp.array(rng(shape, dtype))
xsp = sparse.BCOO.fromdense(x, n_batch=n_batch, n_dense=n_dense)

self.assertEqual(xsp[None].n_batch, xsp.n_batch + 1)
self.assertArraysEqual(xsp[None].todense(), x[None])

if len(shape) >= 1:
self.assertEqual(xsp[:, None].n_batch, xsp.n_batch if xsp.n_batch < 1 else xsp.n_batch + 1)
self.assertArraysEqual(xsp[:, None].todense(), x[:, None])
self.assertArraysEqual(xsp[:, None, None].todense(), x[:, None, None])
if len(shape) >= 2:
self.assertEqual(xsp[:, :, None].n_batch, xsp.n_batch if xsp.n_batch < 2 else xsp.n_batch + 1)
self.assertArraysEqual(xsp[:, :, None].todense(), x[:, :, None])
self.assertArraysEqual(xsp[:, None, :, None].todense(), x[:, None, :, None])

Expand Down

0 comments on commit 0175479

Please sign in to comment.