Skip to content

Commit

Permalink
Merge pull request jax-ml#9544 from SaturdayGenfo:adds-matrix-sqrt
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 429264231
  • Loading branch information
jax authors committed Feb 17, 2022
2 parents bd2a6a0 + cb73232 commit 15295a8
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 1 deletion.
63 changes: 62 additions & 1 deletion jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,18 @@ def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
del overwrite_a, overwrite_b, turbo, check_finite
return _eigh(a, b, lower, eigvals_only, eigvals, type)


@partial(jit, static_argnames=('output',))
def _schur(a, output):
  """Jitted Schur decomposition used by the public ``schur`` wrapper.

  Args:
    a: array to decompose; it must expose ``dtype`` and ``astype``.
    output: static string. When it equals ``"complex"`` the input is first
      cast to ``jnp.result_type(a.dtype, 0j)``; any other value leaves the
      input as it is.

  Returns:
    Whatever ``lax_linalg.schur`` returns for the (possibly cast) input.
  """
  if output != "complex":
    return lax_linalg.schur(a)
  # Promote to the matching complex dtype before decomposing.
  complex_dtype = jnp.result_type(a.dtype, 0j)
  return lax_linalg.schur(a.astype(complex_dtype))

@_wraps(scipy.linalg.schur)
def schur(a, output='real'):
  # Public entry point: validate `output` eagerly (outside of jit) so that an
  # invalid value produces a clear ValueError, then defer to the jitted
  # `_schur` helper.
  if output in ('real', 'complex'):
    return _schur(a, output)
  message = "Expected 'output' to be either 'real' or 'complex', got output={}."
  raise ValueError(message.format(output))

@_wraps(scipy.linalg.inv)
def inv(a, overwrite_a=False, check_finite=True):
Expand Down Expand Up @@ -595,3 +606,53 @@ def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
unitary, posdef, _ = lax_polar.polar(a, side=side, method=method, eps=eps,
maxiter=maxiter)
return unitary, posdef

@jit
def _sqrtm_triu(T):
  """Square root of an upper-triangular matrix, one entry at a time.

  Implements Björck, Å., & Hammarling, S. (1983), "A Schur method for the
  square root of a matrix", Linear Algebra and its Applications, 52, 127-140.

  The diagonal of the result is the elementwise square root of the diagonal
  of ``T``. Each strictly-upper entry ``(row, col)`` is then filled in,
  column by column and from the diagonal upwards, as

    (T[row, col] - sum_{row < k < col} U[row, k] * U[k, col])
      / (U[row, row] + U[col, col])

  with the entry forced to zero whenever the numerator is exactly zero, so
  that a zero denominator paired with a zero numerator does not yield NaN.

  Args:
    T: square array; only its diagonal and strictly-upper entries are read.

  Returns:
    Array ``U`` with the same shape as ``jnp.diag(jnp.diag(T))``.
  """
  root_diag = jnp.sqrt(jnp.diag(T))
  size = root_diag.size

  def fill_entry(step, carry):
    col, root = carry
    # Walk each column upwards, starting just above the diagonal.
    row = col - 1 - step

    def accumulate(k, running):
      return running + root[row, k] * root[k, col]

    inner = lax.fori_loop(row + 1, col, accumulate, 0.0)
    entry = jnp.where(T[row, col] == inner, 0.0,
                      (T[row, col] - inner) / (root_diag[row] + root_diag[col]))
    return col, root.at[row, col].set(entry)

  def fill_column(col, root):
    # Column `col` has exactly `col` entries above the diagonal.
    return lax.fori_loop(0, col, fill_entry, (col, root))[1]

  return lax.fori_loop(0, size, fill_column, jnp.diag(root_diag))

@jit
def _sqrtm(A):
  """Matrix square root via the complex Schur form.

  Decomposes ``A`` with ``schur(A, output='complex')`` into ``(T, Z)``, takes
  the square root of ``T`` with ``_sqrtm_triu`` and returns
  ``Z @ sqrt(T) @ conj(Z.T)``, with both products evaluated at
  ``lax.Precision.HIGHEST``.
  """
  triangular, unitary = schur(A, output='complex')
  root = _sqrtm_triu(triangular)
  left = jnp.matmul(unitary, root, precision=lax.Precision.HIGHEST)
  return jnp.matmul(left, jnp.conj(unitary.T),
                    precision=lax.Precision.HIGHEST)

@_wraps(scipy.linalg.sqrtm,
        lax_description="""
This differs from ``scipy.linalg.sqrtm`` in that the return type of
``jax.scipy.linalg.sqrtm`` is always ``complex64`` for 32-bit input,
and ``complex128`` for 64-bit input.
This function implements the complex Schur method described in [A]. It does not use recursive blocking
to speed up computations as a Sylvester Equation solver is not available yet in JAX.
[A] Björck, Å., & Hammarling, S. (1983).
"A Schur method for the square root of a matrix". Linear algebra and its applications, 52, 127-140.
""")
def sqrtm(A, blocksize=1):
  # `blocksize` exists for signature compatibility with scipy; only the
  # unblocked algorithm is available, so any value above 1 is rejected before
  # the jitted `_sqrtm` implementation is called.
  wants_blocked = blocksize > 1
  if wants_blocked:
    raise NotImplementedError("Blocked version is not implemented yet.")
  return _sqrtm(A)
2 changes: 2 additions & 0 deletions jax/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
lu_solve as lu_solve,
polar as polar,
qr as qr,
schur as schur,
sqrtm as sqrtm,
solve as solve,
solve_triangular as solve_triangular,
svd as svd,
Expand Down
78 changes: 78 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,84 @@ def expm(x):
return jsp.linalg.expm(x, upper_triangular=False, max_squarings=16)
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
rtol=tol)
@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name="_shape={}".format(
                jtu.format_shape_dtype_string(shape, dtype)),
            shape=shape,
            dtype=dtype,
        )
        for shape in [(4, 4), (15, 15), (50, 50), (100, 100)]
        for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSchur(self, shape, dtype):
  # Compares jsp.linalg.schur with scipy's schur on random square inputs and
  # checks that the function also runs under compilation.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker)
  self._CompileAndCheck(jsp.linalg.schur, args_maker)

@parameterized.named_parameters(
    jtu.cases_from_list({
        "testcase_name":
        "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
        "shape" : shape, "dtype" : dtype
    } for shape in [(4, 4), (15, 15), (50, 50), (100, 100)]
      for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSqrtmPSDMatrix(self, shape, dtype):
  # Checks against scipy.linalg.sqrtm when the principal square root
  # is guaranteed to be unique (i.e. no negative real eigenvalue).
  rng = jtu.rand_default(self.rng())
  arg = rng(shape, dtype)
  # Use the conjugate transpose: `arg @ arg.T` is only positive semi-definite
  # for real inputs. For complex dtypes it is merely complex-symmetric and may
  # have eigenvalues anywhere in the complex plane, which would break the
  # uniqueness assumption this test relies on. For real dtypes `conj()` is a
  # no-op, so those cases are unchanged.
  mat = arg @ arg.conj().T
  args_maker = lambda : [mat]
  # Single-precision dtypes get a looser tolerance.
  if dtype == np.float32 or dtype == np.complex64:
    tol = 1e-4
  else:
    tol = 1e-8
  self._CheckAgainstNumpy(osp.linalg.sqrtm,
                          jsp.linalg.sqrtm,
                          args_maker,
                          tol=tol,
                          check_dtypes=False)
  self._CompileAndCheck(jsp.linalg.sqrtm, args_maker)

@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name="_shape={}".format(
                jtu.format_shape_dtype_string(shape, dtype)),
            shape=shape,
            dtype=dtype,
        )
        for shape in [(4, 4), (15, 15), (50, 50), (100, 100)]
        for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSqrtmGenMatrix(self, shape, dtype):
  # For a general random matrix there is no reference root to compare with,
  # so only the defining property R @ R == A is checked.
  single_precision = dtype == np.float32 or dtype == np.complex64
  tol = 1e-3 if single_precision else 1e-8
  sampler = jtu.rand_default(self.rng())
  matrix = sampler(shape, dtype)
  root = jsp.linalg.sqrtm(matrix)
  self.assertAllClose(root @ root, matrix, atol=tol, check_dtypes=False)

@parameterized.named_parameters(
    jtu.cases_from_list(
        dict(
            testcase_name="_diag={}".format((diag, dtype)),
            diag=diag,
            expected=expected,
            dtype=dtype,
        )
        for diag, expected in [([1, 0, 0], [1, 0, 0]),
                               ([0, 4, 0], [0, 2, 0]),
                               ([0, 0, 0, 9], [0, 0, 0, 3]),
                               ([0, 0, 9, 0, 0, 4], [0, 0, 3, 0, 0, 2])]
        for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSqrtmEdgeCase(self, diag, expected, dtype):
  """
  Tests the zero numerator condition: diagonal inputs with zero entries
  must produce the diagonal of elementwise roots.
  """
  matrix = jnp.diag(jnp.array(diag)).astype(dtype)
  want = jnp.diag(jnp.array(expected))
  got = jsp.linalg.sqrtm(matrix)

  self.assertAllClose(got, want, check_dtypes=False)


class LaxLinalgTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 15295a8

Please sign in to comment.