
Merge pull request jax-ml#11052 from jakevdp:x64-lax-scipy-test
PiperOrigin-RevId: 454940487
jax authors committed Jun 14, 2022
2 parents b3130b7 + e888e7c commit d418e89
Showing 5 changed files with 48 additions and 37 deletions.
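This merge makes the lax/scipy special-function and linear-algebra code dtype-clean so the test suite can run under `jax_numpy_dtype_promotion='strict'` with 64-bit mode enabled. As background, here is a minimal sketch of the two promotion modes (illustrative only, not part of the diff):

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)  # assumption: x64 mode, as in these tests

x32 = jnp.ones(3, dtype=jnp.float32)
x64 = jnp.ones(3, dtype=jnp.float64)

with jax.numpy_dtype_promotion('standard'):
    print((x32 + x64).dtype)           # float64: implicit promotion is allowed

with jax.numpy_dtype_promotion('strict'):
    y = x32 + x64.astype(x32.dtype)    # OK: dtypes match after the explicit cast
    # x32 + x64 here would raise a dtype-promotion error instead
```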
8 changes: 4 additions & 4 deletions jax/_src/lax/eigh.py
```diff
@@ -132,7 +132,7 @@ def _projector_subspace(P, H, n, rank, maxiter=2):
   X = _mask(X, (n, rank))

   H_norm = jnp_linalg.norm(H)
-  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm
+  thresh = 10.0 * float(jnp.finfo(X.dtype).eps) * H_norm

   # First iteration skips the matmul.
   def body_f_after_matmul(X):
```
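The `float(...)` cast works because Python scalars are weakly typed in JAX: they adopt the dtype of the array they combine with, while `jnp.finfo(...).eps` is a NumPy scalar carrying an explicit dtype. A small sketch of the distinction (names illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)
with jax.numpy_dtype_promotion('strict'):
    eps = float(np.finfo(x.dtype).eps)        # weak Python float, no explicit dtype
    thresh = 10.0 * eps * jnp.linalg.norm(x)  # stays float32 end to end
```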
```diff
@@ -190,7 +190,7 @@ def split_spectrum(H, n, split_point, V0=None):
     rank: The dynamic size of the m subblock.
   """
   N, _ = H.shape
-  H_shift = H - split_point * jnp.eye(N, dtype=H.dtype)
+  H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype)
   U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
   P = -0.5 * (U - _mask(jnp.eye(N, dtype=H.dtype), (n, n)))
   rank = jnp.round(jnp.trace(jnp.real(P))).astype(jnp.int32)
@@ -331,6 +331,7 @@ def base_case(B, offset, b, agenda, blocks, eigenvectors):
     eig_vals = _mask(eig_vals, (b,))
     eig_vecs = jnp.dot(V, eig_vecs)

+    eig_vals = eig_vals.astype(eig_vecs.dtype)
     blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b))
     eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
     return agenda, blocks, eigenvectors
```
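`base_case` diagonalizes a (possibly complex) Hermitian block, and eigenvalues of a Hermitian matrix are real; storing them into the complex `blocks` buffer therefore needs an explicit cast under strict promotion. The same real-vs-complex split is visible at the top level:

```python
import jax.numpy as jnp

H = jnp.array([[2.0, 1.0j], [-1.0j, 2.0]], dtype=jnp.complex64)
eig_vals, eig_vecs = jnp.linalg.eigh(H)
print(eig_vals.dtype, eig_vecs.dtype)       # float32 complex64
eig_vals = eig_vals.astype(eig_vecs.dtype)  # cast before storing with the vectors
```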
```diff
@@ -370,12 +371,11 @@ def loop_cond(state):
     i = min(2 * i, N)
     buckets.append(i)
     branches.append(partial(recursive_case, i))
-  buckets = jnp.array(buckets)
+  buckets = jnp.array(buckets, dtype='int32')

   def loop_body(state):
     agenda, blocks, eigenvectors = state
     (offset, b), agenda = agenda.pop()
-
     which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets)
     choice = jnp.argmin(which)
     return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors)
```
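The agenda loop picks the smallest precomputed bucket that fits the current block size `b` and dispatches to the matching branch with `lax.switch`. Pinning `buckets` to int32 keeps the comparison in a single dtype (with x64 enabled, `jnp.array` of Python ints would default to int64). A sketch of the selection logic with made-up sizes:

```python
import jax.numpy as jnp

buckets = jnp.array([1, 2, 4, 8], dtype='int32')   # pinned; int64 under x64 otherwise
b = 3
# Mask out buckets that are too small, then take the smallest remaining one;
# `choice` is the branch index handed to lax.switch in the real code.
which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets)
choice = jnp.argmin(which)
print(choice)   # 2, i.e. the bucket of size 4
```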
2 changes: 1 addition & 1 deletion jax/_src/lax/qdwh.py
```diff
@@ -123,7 +123,7 @@ def _qdwh(x, m, n, is_hermitian, max_iterations, eps):
   if eps is None:
     eps = float(jnp.finfo(x.dtype).eps)
   alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) *
-           jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf)))
+           jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf))).astype(x.dtype)
   l = eps

   u = x / alpha
```
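Matrix norms of a complex matrix are real, so without the `.astype(x.dtype)` the scale `alpha` would be a float array and `x / alpha` would mix float and complex dtypes. A sketch:

```python
import jax
import jax.numpy as jnp

x = jnp.eye(3, dtype=jnp.complex64)
with jax.numpy_dtype_promotion('strict'):
    # Both norms are float32 here; cast the product back to complex64.
    alpha = (jnp.sqrt(jnp.linalg.norm(x, ord=1)) *
             jnp.sqrt(jnp.linalg.norm(x, ord=jnp.inf))).astype(x.dtype)
    u = x / alpha   # complex64 / complex64: no implicit promotion needed
```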
1 change: 1 addition & 0 deletions jax/_src/scipy/linalg.py
```diff
@@ -699,6 +699,7 @@ def polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None):
                      f"side='left', got {a.shape} with side={side}")
   elif method == "svd":
     u_svd, s_svd, vh_svd = lax_linalg.svd(a, full_matrices=False)
+    s_svd = s_svd.astype(u_svd.dtype)
     unitary = u_svd @ vh_svd
     if side == "right":
       # a = u * p
```
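Similarly, `svd` returns real singular values even for complex input, so `s_svd` must be cast before it is combined with the complex factors. A sketch of the `method='svd'` path for `side='right'` (the reconstruction of `p` is shown for illustration):

```python
import jax.numpy as jnp

a = jnp.array([[1.0 + 1.0j, 0.5], [0.2, 2.0 - 1.0j]], dtype=jnp.complex64)
u, s, vh = jnp.linalg.svd(a, full_matrices=False)
print(s.dtype)              # float32: singular values are always real
s = s.astype(u.dtype)       # cast before mixing with the complex factors
unitary = u @ vh
posdef = (vh.conj().T * s) @ vh   # a ≈ unitary @ posdef
```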
39 changes: 22 additions & 17 deletions jax/_src/scipy/special.py
```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

 from functools import partial
+from typing import Any, Optional, Tuple

 import numpy as np
 import scipy.special as osp_special
@@ -27,8 +28,6 @@
 from jax._src.numpy.lax_numpy import _reduction_dims, _promote_args_inexact
 from jax._src.numpy.util import _wraps

-from typing import Optional, Tuple
-

 @_wraps(osp_special.gammaln)
 def gammaln(x):
```
```diff
@@ -680,7 +679,7 @@ def i1(x):


 def _gen_recurrence_mask(
-    l_max: int, is_normalized: bool = True
+    l_max: int, is_normalized: bool, dtype: Any
 ) -> Tuple[jnp.ndarray, jnp.ndarray]:
   """Generates mask for recurrence relation on the remaining entries.
@@ -697,7 +696,10 @@
   """

   # Computes all coefficients.
-  m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1]
+  m_mat, l_mat = jnp.meshgrid(
+      jnp.arange(l_max + 1, dtype=dtype),
+      jnp.arange(l_max + 1, dtype=dtype),
+      indexing='ij')
   if is_normalized:
     c0 = l_mat * l_mat
     c1 = m_mat * m_mat
```
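`jnp.mgrid[:n, :n]` produces integer grids in the default int dtype, which would promote once combined with floating-point coefficients. Building the grids from `jnp.arange(..., dtype=dtype)` keeps the whole recurrence in the caller's dtype, and `indexing='ij'` reproduces `mgrid`'s layout. For example:

```python
import jax.numpy as jnp

dtype = jnp.float32
m_mat, l_mat = jnp.meshgrid(
    jnp.arange(4, dtype=dtype),
    jnp.arange(4, dtype=dtype),
    indexing='ij')                  # 'ij' matches jnp.mgrid's ordering
print((l_mat * l_mat).dtype)        # float32: no int -> float promotion later
mask = (m_mat + l_mat == 2).astype(dtype)  # bool -> float32, like the 3-D mask below
```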
```diff
@@ -711,7 +713,7 @@

   d0_mask_indices = jnp.triu_indices(l_max + 1, 1)
   d1_mask_indices = jnp.triu_indices(l_max + 1, 2)
-  d_zeros = jnp.zeros((l_max + 1, l_max + 1))
+  d_zeros = jnp.zeros((l_max + 1, l_max + 1), dtype=dtype)
   d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices])
   d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices])

@@ -720,7 +722,7 @@
   # j = jnp.arange(l_max + 1)[None, :, None]
   # k = jnp.arange(l_max + 1)[None, None, :]
   i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1]
-  mask = 1.0 * (i + j - k == 0)
+  mask = (i + j - k == 0).astype(dtype)

   d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask)
   d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask)
@@ -762,24 +764,27 @@ def _gen_derivatives(p: jnp.ndarray,
         'Negative orders for normalization is not implemented yet.')
   else:
     if num_l > 1:
-      l_vec = jnp.arange(1, num_l - 1)
+      l_vec = jnp.arange(1, num_l - 1, dtype=x.dtype)
       p_p1 = p[1, 1:num_l - 1, :]
       coeff = -1.0 / ((l_vec + 1) * l_vec)
       update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
       p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)

     if num_l > 2:
-      l_vec = jnp.arange(2, num_l - 1)
+      l_vec = jnp.arange(2, num_l - 1, dtype=x.dtype)
       p_p2 = p[2, 2:num_l - 1, :]
       coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec * (l_vec - 1))
       update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
       p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

-  m_mat, l_mat = jnp.mgrid[:num_m, :num_l]
+  m_mat, l_mat = jnp.meshgrid(
+      jnp.arange(num_m, dtype=x.dtype),
+      jnp.arange(num_l, dtype=x.dtype),
+      indexing='ij')

-  coeff_zeros = jnp.zeros((num_m, num_l))
+  coeff_zeros = jnp.zeros((num_m, num_l), dtype=x.dtype)
   upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
-  zero_vec = jnp.zeros((num_l,))
+  zero_vec = jnp.zeros((num_l,), dtype=x.dtype)

   a0 = -0.5 / (m_mat - 1.0)
   a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
@@ -809,13 +814,13 @@

   # Special treatment of the singularity at m = 1.
   if num_m > 1:
-    l_vec = jnp.arange(num_l)
+    l_vec = jnp.arange(num_l, dtype=p.dtype)
     g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
     if num_l > 2:
       g0 = g0 - p[2, :, :]
     p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
     p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
-    p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x,)))
+    p_derivative = p_derivative.at[1, 0, :].set(0)

   return p_derivative
```
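Setting the slice with the weak scalar `0` instead of `jnp.zeros((num_x,))` sidesteps dtype bookkeeping: a Python scalar adopts the buffer's dtype, while `jnp.zeros` without `dtype=` creates the default float type (float64 when x64 is enabled). E.g.:

```python
import jax.numpy as jnp

p_derivative = jnp.ones((2, 3, 4), dtype=jnp.float32)
p_derivative = p_derivative.at[1, 0, :].set(0)  # weak 0 takes the buffer's dtype
print(p_derivative.dtype)                       # float32
```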

```diff
@@ -869,10 +874,10 @@ def _gen_associated_legendre(l_max: int,
     of the ALFs at `x`; the dimensions in the sequence of order, degree, and
     evalution points.
   """
-  p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))
+  p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]), dtype=x.dtype)

-  a_idx = jnp.arange(1, l_max + 1)
-  b_idx = jnp.arange(l_max)
+  a_idx = jnp.arange(1, l_max + 1, dtype=x.dtype)
+  b_idx = jnp.arange(l_max, dtype=x.dtype)
   if is_normalized:
     initial_value = 0.5 / jnp.sqrt(jnp.pi)  # The initial value p(0,0).
     f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
@@ -901,7 +906,7 @@

   # Compute the remaining entries with recurrence.
   d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
-      l_max, is_normalized=is_normalized)
+      l_max, is_normalized=is_normalized, dtype=x.dtype)

   def body_fun(i, p_val):
     coeff_0 = d0_mask_3d[i]
```
35 changes: 20 additions & 15 deletions tests/lax_scipy_test.py
```diff
@@ -147,6 +147,7 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
 ]


+@jtu.with_config(jax_numpy_dtype_promotion='strict')
 class LaxBackedScipyTests(jtu.JaxTestCase):
   """Tests for LAX-backed Scipy implementation."""
```

```diff
@@ -249,6 +250,7 @@ def testLogSumExpNans(self):
           if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
       for rec in JAX_SPECIAL_FUNCTION_RECORDS))
   @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
+  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                           test_autodiff, nondiff_argnums):
     if (jtu.device_under_test() == "cpu" and
@@ -344,7 +346,7 @@ def scipy_fun(z, m=l_max, n=l_max):
       return np.dstack(vals), np.dstack(derivs)

     self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5,
-                            atol=3e-3)
+                            atol=3e-3, check_dtypes=False)
     self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)

   @parameterized.named_parameters(jtu.cases_from_list(
@@ -377,9 +379,11 @@ def scipy_fun(z, m=l_max, n=l_max):
         a_normalized[m, l] = c2 * a[m, l]
       return a_normalized

-    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, rtol=1e-5, atol=1e-5)
+    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
+                            rtol=1e-5, atol=1e-5, check_dtypes=False)
     self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)

+  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testSphHarmAccuracy(self):
     m = jnp.arange(-3, 3)[:, None]
     n = jnp.arange(3, 6)
```
```diff
@@ -393,6 +397,7 @@ def testSphHarmAccuracy(self):

     self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

+  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testSphHarmOrderZeroDegreeZero(self):
     """Tests the spherical harmonics of order zero and degree zero."""
     theta = jnp.array([0.3])
@@ -406,6 +411,7 @@ def testSphHarmOrderZeroDegreeZero(self):
     self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

   @jtu.skip_on_devices("rocm")  # rtol and atol needs to be adjusted for ROCm
+  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testSphHarmOrderZeroDegreeOne(self):
     """Tests the spherical harmonics of order one and degree zero."""
     theta = jnp.array([2.0])
@@ -418,6 +424,7 @@

     self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)

+  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testSphHarmOrderOneDegreeOne(self):
     """Tests the spherical harmonics of order one and degree one."""
     theta = jnp.array([2.0])
@@ -432,11 +439,11 @@
     self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

   @parameterized.named_parameters(jtu.cases_from_list(
-      {'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
-          l_max, num_z, dtype),
+      {'testcase_name': f'_maxdegree={l_max}_inputsize={num_z}_dtype={dtype.__name__}',
        'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
       for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
       for dtype in jtu.dtypes.all_integer))
+  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
     """Tests against JIT compatibility and Numpy."""
     n_max = l_max
@@ -458,6 +465,7 @@ def args_maker():
     with self.subTest('Test against numpy.'):
       self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)

+  @jax.numpy_dtype_promotion('standard')  # This test explicitly exercises dtype promotion
   def testSphHarmCornerCaseWithWrongNmax(self):
     """Tests the corner case where `n_max` is not the maximum value of `n`."""
     m = jnp.array([2])
```
```diff
@@ -527,7 +535,7 @@ def testPolar(
       should_be_eye = np.matmul(unitary.conj().T, unitary)
     else:
       should_be_eye = np.matmul(unitary, unitary.conj().T)
-    tol = 500 * jnp.finfo(matrix.dtype).eps
+    tol = 500 * float(jnp.finfo(matrix.dtype).eps)
     eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
     with self.subTest('Test unitarity.'):
       self.assertAllClose(
@@ -541,7 +549,7 @@
     ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
     negative_ev = jnp.sum(ev < 0.)
     with self.subTest('Test positive definiteness.'):
-      assert negative_ev == 0.
+      self.assertEqual(negative_ev, 0)

     if side == "right":
       recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
@@ -553,17 +561,14 @@

   @parameterized.named_parameters(jtu.cases_from_list(
       {'testcase_name':
-         '_linear_size={}_seed={}_dtype={}_termination_size={}'.format(
-             linear_size, seed, jnp.dtype(dtype).name, termination_size
-         ),
-     'linear_size': linear_size, 'seed': seed, 'dtype': dtype,
+         '_linear_size={}_dtype={}_termination_size={}'.format(
+             linear_size, jnp.dtype(dtype).name, termination_size),
+     'linear_size': linear_size, 'dtype': dtype,
      'termination_size': termination_size}
      for linear_size in linear_sizes
-     for seed in seeds
-     for dtype in jtu.dtypes.supported([jnp.float32, jnp.float64, jnp.complex64,
-                                        jnp.complex128])
+     for dtype in jtu.dtypes.floating + jtu.dtypes.complex
      for termination_size in [1, 19]))
-  def test_spectral_dac_eigh(self, linear_size, seed, dtype, termination_size):
+  def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
     if jtu.device_under_test() != "tpu" and termination_size != 1:
       raise unittest.SkipTest(
           "Termination sizes greater than 1 only work on TPU")
@@ -578,7 +583,7 @@ def test_spectral_dac_eigh(self, linear_size, seed, dtype, termination_size):
     evs, V = jax._src.lax.eigh.eigh(H, termination_size=termination_size)
     ev_exp, eV_exp = jnp.linalg.eigh(H)
     HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
-    vV = evs[None, :] * V
+    vV = evs.astype(V.dtype)[None, :] * V
     eps = jnp.finfo(H.dtype).eps
     atol = jnp.linalg.norm(H) * eps
     self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
```
