Skip to content

Commit

Permalink
Fix test failures due to type mismatches in linear algebra tests.
Browse files Browse the repository at this point in the history
Minor code cleanups.
  • Loading branch information
hawkinsp committed Dec 21, 2018
1 parent df59c51 commit a438645
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
6 changes: 3 additions & 3 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ def slogdet(a):
raise ValueError(msg.format(a_shape))
lu, pivot = lax_linalg.lu(a)
diag = np.diagonal(lu, axis1=-2, axis2=-1)
is_zero = np.any(diag == 0, axis=-1)
is_zero = np.any(diag == np.array(0, dtype=dtype), axis=-1)
parity = np.count_nonzero(pivot != np.arange(a_shape[-1]), axis=-1)
if np.iscomplexobj(a):
sign = np.prod(diag / np.abs(diag))
else:
sign = 1
sign = np.array(1, dtype=dtype)
parity = parity + np.count_nonzero(diag < 0)
sign = np.where(is_zero,
np.array(0, dtype=dtype),
sign * np.array(-2 * (parity % 2) + 1, dtype=dtype))
logdet = np.where(
is_zero, np.array(-np.inf, dtype=dtype),
np.sum(np.log(np.abs(diag)), axis=-1))
return sign, logdet
return sign, np.real(logdet)


@_wraps(onp.linalg.det)
Expand Down
2 changes: 1 addition & 1 deletion jax/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
m, n = np.shape(a)
lu, pivots = lax_linalg.lu(a)
permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
p = np.array(permutation == np.arange(m)[:, None], dtype=dtype)
p = np.real(np.array(permutation == np.arange(m)[:, None], dtype=dtype))
k = min(m, n)
l = np.tril(lu, -1)[:, :k] + np.eye(m, k, dtype=dtype)
u = np.triu(lu)[:k, :]
Expand Down
3 changes: 2 additions & 1 deletion jaxlib/lapack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ from scipy.linalg.cython_lapack cimport sgetrf, dgetrf, cgetrf, spotrf, dpotrf

import numpy as np
from jaxlib import xla_client
from jaxlib.xla_client import Shape

Shape = xla_client.Shape


cdef register_cpu_custom_call_target(fn_name, void* fn):
Expand Down
4 changes: 2 additions & 2 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def testLu(self, shape, dtype, rng):
self.skipTest("No LU implementation available")
args_maker = lambda: [rng(shape, dtype)]

self._CheckAgainstNumpy(osp.linalg.lu, jsp.linalg.lu, args_maker,
self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker,
check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

Expand All @@ -216,7 +216,7 @@ def testLuFactor(self, n, dtype, rng):
self.skipTest("No LU implementation available")
args_maker = lambda: [rng((n, n), dtype)]

self._CheckAgainstNumpy(osp.linalg.lu_factor, jsp.linalg.lu_factor,
self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
args_maker, check_dtypes=True, tol=1e-3)
self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)

Expand Down

0 comments on commit a438645

Please sign in to comment.