Skip to content

Commit

Permalink
[sparse] support type promotion in CSR/COO matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 17, 2021
1 parent bb3f198 commit 50ce1db
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
19 changes: 13 additions & 6 deletions jax/experimental/sparse/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from jax._src.lib import cusparse
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.numpy.lax_numpy import _promote_dtypes
import jax.numpy as jnp

xb = xla_bridge
Expand Down Expand Up @@ -807,10 +808,12 @@ def todense(self):
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape)

def matvec(self, v):
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape)
data, v = _promote_dtypes(self.data, v)
return csr_matvec(data, self.indices, self.indptr, v, shape=self.shape)

def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)
data, B = _promote_dtypes(self.data, B)
return csr_matmat(data, self.indices, self.indptr, B, shape=self.shape)

def transpose(self, axes=None):
assert axes is None
Expand Down Expand Up @@ -844,10 +847,12 @@ def todense(self):
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape[::-1]).T

def matvec(self, v):
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)
data, v = _promote_dtypes(self.data, v)
return csr_matvec(data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)

def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)
data, B = _promote_dtypes(self.data, B)
return csr_matmat(data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)

def transpose(self, axes=None):
assert axes is None
Expand Down Expand Up @@ -881,10 +886,12 @@ def todense(self):
return coo_todense(self.data, self.row, self.col, shape=self.shape)

def matvec(self, v):
return coo_matvec(self.data, self.row, self.col, v, shape=self.shape)
data, v = _promote_dtypes(self.data, v)
return coo_matvec(data, self.row, self.col, v, shape=self.shape)

def matmat(self, B):
return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)
data, B = _promote_dtypes(self.data, B)
return coo_matmat(data, self.row, self.col, B, shape=self.shape)

def transpose(self, axes=None):
assert axes is None
Expand Down
6 changes: 6 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,9 +1488,15 @@ def test_matmul(self, shape, dtype, Obj, bshape):
rng_b = jtu.rand_default(self.rng())
M = rng(shape, dtype)
Msp = Obj.fromdense(M)

# Test matching type
x = rng_b(bshape, dtype)
x = jnp.asarray(x)
self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)

# Test mismatched type
x = rng_b(bshape, np.int32)
x = jnp.asarray(x)
self.assertAllClose(M @ x, Msp @ x, rtol=MATMUL_TOL)

@parameterized.named_parameters(jtu.cases_from_list(
Expand Down

0 comments on commit 50ce1db

Please sign in to comment.