Skip to content

Commit

Permalink
jax.scipy.sparse.linalg: support sparse matrices as operators
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 5, 2022
1 parent 931bf36 commit 30fd817
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
32 changes: 18 additions & 14 deletions jax/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def _normalize_matvec(f):
raise ValueError(
f'linear operator must be a square matrix, but has shape: {f.shape}')
return partial(_dot, f)
elif hasattr(f, '__matmul__'):
if hasattr(f, 'shape') and len(f.shape) != 2 or f.shape[0] != f.shape[1]:
raise ValueError(
f'linear operator must be a square matrix, but has shape: {f.shape}')
return partial(operator.matmul, f)
else:
# TODO(shoyer): handle sparse arrays?
raise TypeError(
f'linear operator must be either a function or ndarray: {f}')

Expand Down Expand Up @@ -241,10 +245,10 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
Parameters
----------
A: ndarray or function
A: ndarray, function, or matmul-compatible object
2D array or function that calculates the linear map (matrix-vector
product) ``Ax`` when called like ``A(x)``. ``A`` must represent a
hermitian, positive definite matrix, and must return array(s) with the
product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` must represent
a hermitian, positive definite matrix, and must return array(s) with the
same structure and shape as its argument.
b : array or tree of arrays
Right hand side of the linear system representing a single vector. Can be
Expand All @@ -269,7 +273,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
maxiter : integer
Maximum number of iterations. Iteration will stop after maxiter
steps even if the specified tolerance has not been achieved.
M : ndarray or function
M : ndarray, function, or matmul-compatible object
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
Expand Down Expand Up @@ -592,10 +596,10 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
Parameters
----------
A: ndarray or function
A: ndarray, function, or matmul-compatible object
2D array or function that calculates the linear map (matrix-vector
product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with
the same structure and shape as its argument.
product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A``
must return array(s) with the same structure and shape as its argument.
b : array or tree of arrays
Right hand side of the linear system representing a single vector. Can be
stored as an array or Python container of array(s) with any shape.
Expand Down Expand Up @@ -631,7 +635,7 @@ def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
starting from the solution found at the last iteration. If GMRES
halts or is very slow, decreasing this parameter may help.
Default is infinite.
M : ndarray or function
M : ndarray, function, or matmul-compatible object
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
Expand Down Expand Up @@ -709,11 +713,11 @@ def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
Parameters
----------
A: ndarray or function
A: ndarray, function, or matmul-compatible object
2D array or function that calculates the linear map (matrix-vector
product) ``Ax`` when called like ``A(x)``. ``A`` can represent any general
(nonsymmetric) linear operator, and function must return array(s) with the
same structure and shape as its argument.
product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` can represent
any general (nonsymmetric) linear operator, and function must return array(s)
with the same structure and shape as its argument.
b : array or tree of arrays
Right hand side of the linear system representing a single vector. Can be
stored as an array or Python container of array(s) with any shape.
Expand All @@ -737,7 +741,7 @@ def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
maxiter : integer
Maximum number of iterations. Iteration will stop after maxiter
steps even if the specified tolerance has not been achieved.
M : ndarray or function
M : ndarray, function, or matmul-compatible object
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
Expand Down
33 changes: 33 additions & 0 deletions tests/lax_scipy_sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def rand_sym_pos_def(rng, shape, dtype):
return matrix @ matrix.T.conj()


class CustomOperator:
def __init__(self, A):
self.A = A
self.shape = self.A.shape

def __matmul__(self, x):
return self.A @ x


class LaxBackedScipyTests(jtu.JaxTestCase):

def _fetch_preconditioner(self, preconditioner, A, rng=None):
Expand Down Expand Up @@ -164,6 +173,14 @@ def test_cg_pytree(self):
self.assertAlmostEqual(expected["a"], actual["a"], places=6)
self.assertAlmostEqual(expected["b"], actual["b"], places=6)

@jtu.skip_on_devices('tpu')
def test_cg_matmul(self):
A = CustomOperator(2 * jnp.eye(3))
b = jnp.arange(9.0).reshape(3, 3)
expected = b / 2
actual, _ = jax.scipy.sparse.linalg.cg(A, b)
self.assertAllClose(expected, actual)

def test_cg_errors(self):
A = lambda x: x
b = jnp.zeros((2,))
Expand Down Expand Up @@ -313,6 +330,14 @@ def test_bicgstab_weak_types(self):
x, _ = jax.scipy.sparse.linalg.bicgstab(lambda x: x, 1.0)
self.assertTrue(dtypes.is_weakly_typed(x))

@jtu.skip_on_devices('tpu')
def test_bicgstab_matmul(self):
A = CustomOperator(2 * jnp.eye(3))
b = jnp.arange(9.0).reshape(3, 3)
expected = b / 2
actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
self.assertAllClose(expected, actual)

# GMRES
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
Expand Down Expand Up @@ -436,6 +461,14 @@ def test_gmres_pytree(self):
self.assertAlmostEqual(expected["a"], actual["a"], places=5)
self.assertAlmostEqual(expected["b"], actual["b"], places=5)

@jtu.skip_on_devices('tpu')
def test_gmres_matmul(self):
A = CustomOperator(2 * jnp.eye(3))
b = jnp.arange(9.0).reshape(3, 3)
expected = b / 2
actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
self.assertAllClose(expected, actual)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_preconditioner={}".format(
Expand Down

0 comments on commit 30fd817

Please sign in to comment.