Skip to content

Commit

Permalink
[sparse] ensure shapes are represented as tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 16, 2021
1 parent 476ca94 commit 5c31e6d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/sparse/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def _safe_asarray(args):
return map(_asarray_or_float0, args)

def __init__(self, args, *, shape):
self.shape = shape
self.shape = tuple(shape)

def __repr__(self):
name = self.__class__.__name__
Expand Down
6 changes: 3 additions & 3 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,10 +1340,10 @@ def f(X, y):
class SparseObjectTest(jtu.JaxTestCase):
def test_repr(self):
M = sparse.BCOO.fromdense(jnp.arange(5, dtype='float32'))
assert repr(M) == "BCOO(float32[5], nse=4)"
self.assertEqual(repr(M), "BCOO(float32[5], nse=4)")

M_invalid = sparse.BCOO(([], []), shape=100)
assert repr(M_invalid) == "BCOO(<invalid>)"
M_invalid = sparse.BCOO(([], []), shape=(100,))
self.assertEqual(repr(M_invalid), "BCOO(<invalid>)")

@parameterized.named_parameters(
{"testcase_name": "_{}".format(Obj.__name__), "Obj": Obj}
Expand Down

0 comments on commit 5c31e6d

Please sign in to comment.