diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py
index ebf0fa7ea8d8..3391130eaf5c 100644
--- a/jax/_src/scipy/linalg.py
+++ b/jax/_src/scipy/linalg.py
@@ -103,7 +103,18 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
   del overwrite_a, overwrite_b, turbo, check_finite
   return _eigh(a, b, lower, eigvals_only, eigvals, type)
 
-
+@partial(jit, static_argnames=('output',))
+def _schur(a, output):
+  if output == "complex":
+    a = a.astype(jnp.result_type(a.dtype, 0j))
+  return lax_linalg.schur(a)
+
+@_wraps(scipy.linalg.schur)
+def schur(a, output='real'):
+  if output not in ('real', 'complex'):
+    raise ValueError(
+      "Expected 'output' to be either 'real' or 'complex', got output={}.".format(output))
+  return _schur(a, output)
 
 @_wraps(scipy.linalg.inv)
 def inv(a, overwrite_a=False, check_finite=True):
@@ -595,3 +606,53 @@ def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
   unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps,
                                        maxiter=maxiter)
   return unitary, posdef
+
+@jit
+def _sqrtm_triu(T):
+  """
+  Implements Björck, Å., & Hammarling, S. (1983),
+  "A Schur method for the square root of a matrix", Linear Algebra and
+  its Applications, 52, 127-140.
+  """
+  diag = jnp.sqrt(jnp.diag(T))
+  n = diag.size
+  U = jnp.diag(diag)
+
+  def i_loop(l, data):
+    j, U = data
+    i = j - 1 - l
+    s = lax.fori_loop(i + 1, j, lambda k, val: val + U[i, k] * U[k, j], 0.0)
+    value = jnp.where(T[i, j] == s, 0.0,
+                      (T[i, j] - s) / (diag[i] + diag[j]))
+    return j, U.at[i, j].set(value)
+
+  def j_loop(j, U):
+    _, U = lax.fori_loop(0, j, i_loop, (j, U))
+    return U
+
+  U = lax.fori_loop(0, n, j_loop, U)
+  return U
+
+@jit
+def _sqrtm(A):
+  T, Z = schur(A, output='complex')
+  sqrt_T = _sqrtm_triu(T)
+  return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST),
+                    jnp.conj(Z.T), precision=lax.Precision.HIGHEST)
+
+@_wraps(scipy.linalg.sqrtm,
+        lax_description="""
+This differs from ``scipy.linalg.sqrtm`` in that the return type of
+``jax.scipy.linalg.sqrtm`` is always ``complex64`` for 32-bit input,
+and ``complex128`` for 64-bit input.
+
+This function implements the complex Schur method described in [A]. It does not use recursive blocking
+to speed up computations, as a Sylvester equation solver is not yet available in JAX.
+
+[A] Björck, Å., & Hammarling, S. (1983),
+    "A Schur method for the square root of a matrix", Linear Algebra and its Applications, 52, 127-140.
+""") +def sqrtm(A, blocksize=1): + if blocksize > 1: + raise NotImplementedError("Blocked version is not implemented yet.") + return _sqrtm(A) diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 589dd9162c11..2958ce9d43b5 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -30,6 +30,8 @@ lu_solve as lu_solve, polar as polar, qr as qr, + schur as schur, + sqrtm as sqrtm, solve as solve, solve_triangular as solve_triangular, svd as svd, diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5d3863eeabb0..ea34b5be8f3d 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1371,6 +1371,84 @@ def expm(x): return jsp.linalg.expm(x, upper_triangular=False, max_squarings=16) jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol, rtol=tol) + @parameterized.named_parameters( + jtu.cases_from_list({ + "testcase_name": + "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype + } for shape in [(4, 4), (15, 15), (50, 50), (100, 100)] + for dtype in float_types + complex_types)) + @jtu.skip_on_devices("gpu", "tpu") + def testSchur(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker) + self._CompileAndCheck(jsp.linalg.schur, args_maker) + + @parameterized.named_parameters( + jtu.cases_from_list({ + "testcase_name": + "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), + "shape" : shape, "dtype" : dtype + } for shape in [(4, 4), (15, 15), (50, 50), (100, 100)] + for dtype in float_types + complex_types)) + @jtu.skip_on_devices("gpu", "tpu") + def testSqrtmPSDMatrix(self, shape, dtype): + # Checks against scipy.linalg.sqrtm when the principal square root + # is guaranteed to be unique (i.e no negative real eigenvalue) + rng = jtu.rand_default(self.rng()) + arg = rng(shape, dtype) + mat = arg @ arg.T + args_maker = lambda : [mat] + if dtype == np.float32 or dtype == np.complex64: + tol = 1e-4 + else: + tol = 1e-8 + self._CheckAgainstNumpy(osp.linalg.sqrtm, + jsp.linalg.sqrtm, + args_maker, + tol=tol, + check_dtypes=False) + self._CompileAndCheck(jsp.linalg.sqrtm, args_maker) + + @parameterized.named_parameters( + jtu.cases_from_list({ + "testcase_name": + "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), + "shape" : shape, "dtype" : dtype + } for shape in [(4, 4), (15, 15), (50, 50), (100, 100)] + for dtype in float_types + complex_types)) + @jtu.skip_on_devices("gpu", "tpu") + def testSqrtmGenMatrix(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + arg = rng(shape, dtype) + if dtype == np.float32 or dtype == np.complex64: + tol = 1e-3 + else: + tol = 1e-8 + R = jsp.linalg.sqrtm(arg) + self.assertAllClose(R @ R, arg, atol=tol, check_dtypes=False) + + @parameterized.named_parameters( + jtu.cases_from_list({ + "testcase_name": + "_diag={}".format((diag, dtype)), + "diag" : diag, "expected": expected, "dtype" : dtype + } for diag, expected in [([1, 0, 0], [1, 0, 0]), ([0, 4, 0], [0, 2, 0]), + ([0, 0, 0, 9],[0, 0, 0, 3]), + ([0, 0, 9, 0, 0, 4], [0, 0, 3, 0, 0, 2])] + for dtype in float_types + complex_types)) + @jtu.skip_on_devices("gpu", "tpu") + def testSqrtmEdgeCase(self, diag, expected, dtype): + """ + Tests the zero numerator condition + """ + mat = jnp.diag(jnp.array(diag)).astype(dtype) + expected = jnp.diag(jnp.array(expected)) + root = jsp.linalg.sqrtm(mat) + + self.assertAllClose(root, expected, check_dtypes=False) class 
LaxLinalgTest(jtu.JaxTestCase):
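
Not part of the patch: a minimal usage sketch of the two new public entry points added above (`jax.scipy.linalg.schur` and `jax.scipy.linalg.sqrtm`), checking each factorization by reconstruction. It assumes the signatures shown in this diff and should be run on CPU, since the tests above skip these ops on GPU/TPU.

```python
# Usage sketch (not part of the patch): exercises the schur and sqrtm
# wrappers added in this diff, assuming the signatures shown above.
import numpy as np
import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

a = jnp.asarray(np.random.default_rng(0).standard_normal((4, 4)),
                dtype=jnp.float32)

# Complex Schur form: a = Z @ T @ Z^H with T upper triangular.
T, Z = jsp_linalg.schur(a, output='complex')
print(jnp.allclose(Z @ T @ jnp.conj(Z.T), a.astype(T.dtype), atol=1e-5))

# Principal matrix square root: R @ R should recover a (complex64 output
# for 32-bit input, as documented in the lax_description above).
R = jsp_linalg.sqrtm(a)
print(jnp.allclose(R @ R, a.astype(R.dtype), atol=1e-3))
```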
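For context on `_sqrtm_triu` above: it computes the square root `U` of an upper triangular `T` via the Björck & Hammarling (1983) recurrence, `U[i, i] = sqrt(T[i, i])` and, for `i < j`, `U[i, j] = (T[i, j] - sum_{k=i+1}^{j-1} U[i, k] * U[k, j]) / (U[i, i] + U[j, j])`, with the division skipped when the numerator is zero. The sketch below (not part of the patch) writes the same recurrence with plain Python loops instead of `lax.fori_loop`, purely as a reading aid.

```python
# Readability sketch (not part of the patch): the Björck & Hammarling (1983)
# recurrence that _sqrtm_triu implements with lax.fori_loop, written with
# plain loops. T is assumed square and upper triangular (e.g. the complex
# Schur factor).
import numpy as np

def sqrtm_triu_reference(T):
  T = np.asarray(T, dtype=complex)
  n = T.shape[0]
  U = np.diag(np.sqrt(np.diag(T)))      # U[i, i] = sqrt(T[i, i])
  for j in range(n):                    # columns, left to right
    for i in range(j - 1, -1, -1):      # rows, bottom to top within column j
      s = sum(U[i, k] * U[k, j] for k in range(i + 1, j))
      # Zero-numerator guard mirrors the jnp.where in _sqrtm_triu.
      U[i, j] = 0.0 if T[i, j] == s else (T[i, j] - s) / (U[i, i] + U[j, j])
  return U
```

Running `sqrtm_triu_reference` on the complex Schur factor of a small matrix and comparing `Z @ U @ Z.conj().T` against `scipy.linalg.sqrtm` is a quick way to sanity-check the recurrence by hand.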