Skip to content

Commit

Permalink
Add support for select and select_range to jax.scipy.linalg.eigh_trid…
Browse files Browse the repository at this point in the history
…iagonal().

Credit to rmlarsen@.
  • Loading branch information
hawkinsp committed May 6, 2021
1 parent 8ded7ad commit bf6b59e
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 32 deletions.
46 changes: 28 additions & 18 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,12 @@ def block_diag(*arrs):
# TODO(phawkins): use static_argnames when jaxlib 0.1.66 is the minimum and
# remove this wrapper.
@_wraps(scipy.linalg.eigh_tridiagonal)
def eigh_tridiagonal(d, e, tol=None, eigvals_only=False):
return _eigh_tridiagonal(d, e, tol, eigvals_only)
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
select_range=None, tol=None):
return _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol)

@partial(jit, static_argnums=(3,))
def _eigh_tridiagonal(d, e, tol, eigvals_only):
@partial(jit, static_argnums=(2, 3, 4))
def _eigh_tridiagonal(d, e, eigvals_only, select, select_range, tol):
if not eigvals_only:
raise NotImplementedError("Calculation of eigenvectors is not implemented")

Expand Down Expand Up @@ -542,13 +543,22 @@ def cond(iqc):
# lambda_est_max = -lambda_est_min, we have to take as many bisection steps
# as there are bits in the mantissa plus 1.
# The proof is left as an exercise to the reader.
max_it = finfo.nmant + 2

# We want to find [lambda_0, lambda_1, ..., lambda_{n-1}], such that the
# number of eigenvalues of T less than lambda_i is i.
# TODO(rmlarsen): Extend this logic to support the "select" keyword to
# to specify a subset of eigenvalues to compute.
target_counts = jnp.arange(n)
max_it = finfo.nmant + 1

# Determine the indices of the desired eigenvalues, based on select and
# select_range.
if select == 'a':
target_counts = jnp.arange(n)
elif select == 'i':
if select_range[0] > select_range[1]:
raise ValueError('Got empty index range in select_range.')
target_counts = jnp.arange(select_range[0], select_range[1] + 1)
elif select == 'v':
# TODO(phawkins): requires dynamic shape support.
raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
"implemented")
else:
raise ValueError("'select must have a value in {'a', 'i', 'v'}.")

# Run binary search for all desired eigenvalues in parallel, starting from
# the interval lightly wider than the estimated
Expand All @@ -557,15 +567,15 @@ def cond(iqc):
norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
upper = lambda_est_max + norm_slack + fudge * pivmin
lower = jnp.broadcast_to(lower, shape=target_counts.shape)
upper = jnp.broadcast_to(upper, shape=target_counts.shape)
mid = 0.5 * (upper + lower)

# Pre-broadcast the fixed scalars used in the Sturm sequence for improved
# Pre-broadcast the scalars used in the Sturm sequence for improved
# performance.
pivmin = jnp.broadcast_to(pivmin, target_counts.shape)
alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation,
target_counts.shape)
target_shape = jnp.shape(target_counts)
lower = jnp.broadcast_to(lower, shape=target_shape)
upper = jnp.broadcast_to(upper, shape=target_shape)
mid = 0.5 * (upper + lower)
pivmin = jnp.broadcast_to(pivmin, target_shape)
alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

# Start parallel binary searches.
def cond(args):
Expand Down
63 changes: 49 additions & 14 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,23 @@ def expm(x):

class EighTridiagonalTest(jtu.JaxTestCase):

def run_test(self, alpha, beta):
n = alpha.shape[-1]
# scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
# this we call the slower numpy.linalg.eigh.
if np.issubdtype(alpha.dtype, np.complexfloating):
tridiagonal = np.diag(alpha) + np.diag(beta, 1) + np.diag(
np.conj(beta), -1)
eigvals_expected, _ = np.linalg.eigh(tridiagonal)
else:
eigvals_expected = scipy.linalg.eigh_tridiagonal(
alpha, beta, eigvals_only=True)
eigvals = jax.scipy.linalg.eigh_tridiagonal(
alpha, beta, eigvals_only=True)
finfo = np.finfo(alpha.dtype)
atol = 4 * np.sqrt(n) * finfo.eps * np.amax(np.abs(eigvals_expected))
self.assertAllClose(eigvals_expected, eigvals, atol=atol, rtol=1e-4)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_n={n}_dtype={dtype.__name__}",
"n": n, "dtype": dtype}
Expand All @@ -1424,21 +1441,39 @@ def testToeplitz(self, n, dtype):
for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
alpha = a * np.ones([n], dtype=dtype)
beta = b * np.ones([n - 1], dtype=dtype)
# scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
# this we call the slower numpy.linalg.eigh.
if np.issubdtype(alpha.dtype, np.complexfloating):
tridiagonal = np.diag(alpha) + np.diag(beta, 1) + np.diag(
np.conj(beta), -1)
eigvals_expected, _ = np.linalg.eigh(tridiagonal)
else:
eigvals_expected = scipy.linalg.eigh_tridiagonal(
alpha, beta, eigvals_only=True)
eigvals = jax.scipy.linalg.eigh_tridiagonal(
alpha, beta, eigvals_only=True)
finfo = np.finfo(dtype)
atol = 4 * np.sqrt(n) * finfo.eps * np.amax(np.abs(eigvals_expected))
self.assertAllClose(eigvals_expected, eigvals, atol=atol, rtol=1e-4)
self.run_test(alpha, beta)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_n={n}_dtype={dtype.__name__}",
"n": n, "dtype": dtype}
for n in [1, 2, 3, 7, 8, 100]
for dtype in float_types + complex_types))
def testRandomUniform(self, n, dtype):
alpha = jtu.rand_uniform(self.rng())((n,), dtype)
beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
self.run_test(alpha, beta)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype.__name__}",
"dtype": dtype}
for dtype in float_types + complex_types))
def testSelect(self, dtype):
n = 5
alpha = jtu.rand_uniform(self.rng())((n,), dtype)
beta = jtu.rand_uniform(self.rng())((n - 1,), dtype)
eigvals_all = jax.scipy.linalg.eigh_tridiagonal(alpha, beta, select="a",
eigvals_only=True)
eps = np.finfo(alpha.dtype).eps
atol = 2 * n * eps
for first in range(n - 1):
for last in range(first + 1, n - 1):
# Check that we get the expected eigenvalues by selecting by
# index range.
eigvals_index = jax.scipy.linalg.eigh_tridiagonal(
alpha, beta, select="i", select_range=(first, last),
eigvals_only=True)
self.assertAllClose(
eigvals_all[first:(last + 1)], eigvals_index, atol=atol)


if __name__ == "__main__":
Expand Down

0 comments on commit bf6b59e

Please sign in to comment.