Implement np.linalg.slogdet.
Change implementation of np.linalg.det to call np.linalg.slogdet.

Add support for complex64 LU decomposition.
hawkinsp committed Dec 21, 2018
1 parent 3db941f commit b68c93d
Showing 6 changed files with 78 additions and 15 deletions.
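
For context: slogdet returns the sign of the determinant together with the log of its absolute value, which avoids the overflow and underflow that a plain product over the diagonal can hit for large matrices; det is then just sign * exp(logdet). A minimal usage sketch of the new API (following the commit's np = jax.numpy convention; the identity matrix is only an illustration):

import jax.numpy as np

a = np.eye(3)
sign, logdet = np.linalg.slogdet(a)
# For the identity: sign == 1.0 and logdet == 0.0, so det(a) == 1.0.
det = sign * np.exp(logdet)  # equivalent to np.linalg.det(a)
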
5 changes: 3 additions & 2 deletions jax/lax_linalg.py
@@ -190,10 +190,11 @@ def lu_abstract_eval(operand):
 lu_p.def_abstract_eval(lu_abstract_eval)
 xla.translations[lu_p] = lu_translation_rule
 
+_lu_cpu_types = {np.float32, np.float64, np.complex64}
+
 def lu_cpu_translation_rule(c, operand):
   shape = c.GetShape(operand)
-  if len(shape.dimensions()) == 2 and (
-      shape.element_type() == np.float32 or shape.element_type() == np.float64):
+  if len(shape.dimensions()) == 2 and shape.element_type().type in _lu_cpu_types:
     out = lapack.jax_getrf(c, operand)
     lu = c.GetTupleElement(out, 0)
     # Subtract 1 from the pivot to get 0-based indices.
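
The rewritten check dispatches on shape.element_type().type rather than comparing against each dtype in turn. The .type accessor matters for set membership: a NumPy dtype object compares equal to its scalar type class, but the two are distinct objects, so a set lookup (which goes through hash()) needs the class itself. A standalone NumPy sketch of the distinction, assuming element_type() yields a numpy dtype as the old == np.float32 comparison suggests:

import numpy as onp

dt = onp.dtype(onp.complex64)
print(dt == onp.complex64)        # True: dtype-vs-scalar-class comparison coerces
print(dt.type is onp.complex64)   # True: .type recovers the scalar type class
print(dt.type in {onp.float32, onp.float64, onp.complex64})  # True: usable as a set key
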
5 changes: 4 additions & 1 deletion jax/numpy/lax_numpy.py
@@ -992,9 +992,12 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
 @_wraps(onp.diagonal)
 def diagonal(a, offset=0, axis1=0, axis2=1):
   a_shape = shape(a)
+  a_ndims = len(a_shape)
 
   # Move the two dimensions to the end.
-  perm = [i for i in range(len(a_shape)) if i != axis1 and i != axis2]
+  axis1 %= a_ndims
+  axis2 %= a_ndims
+  perm = [i for i in range(a_ndims) if i != axis1 and i != axis2]
   perm = perm + [axis1, axis2]
   a = lax.transpose(a, perm)
 
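
The two new modulo lines normalize negative axis arguments before perm is built; without them, an i drawn from range(a_ndims) can never equal an axis given as -1 or -2, so both axes would survive the filter and the transpose would be wrong. slogdet relies on this, since it calls diagonal(lu, axis1=-2, axis2=-1). A plain NumPy sketch of the normalization:

import numpy as onp

x = onp.arange(24).reshape(2, 3, 4)
a_ndims = x.ndim
axis1, axis2 = -2, -1       # the axes slogdet passes
axis1 %= a_ndims            # -2 % 3 == 1
axis2 %= a_ndims            # -1 % 3 == 2
perm = [i for i in range(a_ndims) if i != axis1 and i != axis2] + [axis1, axis2]
print(perm)                 # [0, 1, 2]: both diagonal axes already at the end
assert (onp.transpose(x, perm) == x).all()
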
31 changes: 25 additions & 6 deletions jax/numpy/linalg.py
@@ -37,16 +37,35 @@ def cholesky(a):
   return lax_linalg.cholesky(a)
 
 
-@_wraps(onp.linalg.det)
-def det(a):
+@_wraps(onp.linalg.slogdet)
+def slogdet(a):
   dtype = lax._dtype(a)
   a_shape = np.shape(a)
-  if len(a_shape) != 2 or a_shape[-1] != a_shape[-2]:
-    msg = "Argument to det() must be a square matrix, got {}"
+  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
+    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
     raise ValueError(msg.format(a_shape))
   lu, pivot = lax_linalg.lu(a)
-  parity = np.count_nonzero(pivot != np.arange(a_shape[-1])) % 2
-  return np.prod(np.diagonal(lu)) * np.array(-2 * parity + 1, dtype=dtype)
+  diag = np.diagonal(lu, axis1=-2, axis2=-1)
+  is_zero = np.any(diag == 0, axis=-1)
+  parity = np.count_nonzero(pivot != np.arange(a_shape[-1]), axis=-1)
+  if np.iscomplexobj(a):
+    sign = np.prod(diag / np.abs(diag))
+  else:
+    sign = 1
+    parity = parity + np.count_nonzero(diag < 0)
+  sign = np.where(is_zero,
+                  np.array(0, dtype=dtype),
+                  sign * np.array(-2 * (parity % 2) + 1, dtype=dtype))
+  logdet = np.where(
+      is_zero, np.array(-np.inf, dtype=dtype),
+      np.sum(np.log(np.abs(diag)), axis=-1))
+  return sign, logdet
+
+
+@_wraps(onp.linalg.det)
+def det(a):
+  sign, logdet = slogdet(a)
+  return sign * np.exp(logdet)
 
 
 @_wraps(onp.linalg.inv)
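
The arithmetic behind the new slogdet: getrf factors A as P*L*U with L unit lower triangular, so det(A) = (-1)^k * prod(diag(U)), where k is the number of row swaps, counted here as the positions where pivot[i] != i. Splitting prod(diag(U)) into a sign (a unit-modulus complex number in the complex case) and a sum of log-magnitudes gives slogdet without overflow, and a zero anywhere on the diagonal short-circuits to sign 0 and logdet -inf. A hedged NumPy check of the same recipe against onp.linalg.slogdet for a real matrix (scipy.linalg.lu_factor wraps the same getrf and, unlike raw LAPACK, returns 0-based pivots, so the arange comparison carries over):

import numpy as onp
import scipy.linalg

a = onp.random.RandomState(0).randn(4, 4)
lu, piv = scipy.linalg.lu_factor(a)
diag = onp.diagonal(lu)
parity = onp.count_nonzero(piv != onp.arange(4))  # each mismatch is one row swap
sign = (-1.0) ** (parity % 2) * onp.prod(onp.sign(diag))
logdet = onp.sum(onp.log(onp.abs(diag)))
ref_sign, ref_logdet = onp.linalg.slogdet(a)
assert onp.isclose(sign, ref_sign) and onp.isclose(logdet, ref_logdet)
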
21 changes: 20 additions & 1 deletion jaxlib/lapack.pyx
@@ -25,7 +25,7 @@ from libcpp.string cimport string
 from cpython.pycapsule cimport PyCapsule_New
 
 from scipy.linalg.cython_blas cimport strsm, dtrsm
-from scipy.linalg.cython_lapack cimport sgetrf, dgetrf, spotrf, dpotrf
+from scipy.linalg.cython_lapack cimport sgetrf, dgetrf, cgetrf, spotrf, dpotrf
 
 import numpy as np
 from jaxlib import xla_client
@@ -182,6 +182,23 @@ cdef void lapack_dgetrf(void* out_tuple, void** data) nogil:
 register_cpu_custom_call_target(b"lapack_dgetrf", <void*>(lapack_dgetrf))
 
 
+cdef void lapack_cgetrf(void* out_tuple, void** data) nogil:
+  cdef int m = (<int32_t*>(data[0]))[0]
+  cdef int n = (<int32_t*>(data[1]))[0]
+  cdef const float complex* a_in = <float complex*>(data[2])
+
+  cdef void** out = <void**>(out_tuple)
+  cdef float complex* a_out = <float complex*>(out[0])
+  cdef int* ipiv = <int*>(out[1])
+  cdef int* info = <int*>(out[2])
+  if a_out != a_in:
+    memcpy(a_out, a_in, m * n * sizeof(float complex))
+
+  cgetrf(&m, &n, a_out, &m, ipiv, info)
+
+register_cpu_custom_call_target(b"lapack_cgetrf", <void*>(lapack_cgetrf))
+
+
 def jax_getrf(c, a):
   assert sizeof(int32_t) == sizeof(int)
 
@@ -192,6 +209,8 @@ def jax_getrf(c, a):
     fn = b"lapack_sgetrf"
   elif dtype == np.float64:
     fn = b"lapack_dgetrf"
+  elif dtype == np.complex64:
+    fn = b"lapack_cgetrf"
   else:
     raise NotImplementedError("Unsupported dtype {}".format(dtype))
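
lapack_cgetrf mirrors the existing s/d kernels: read m and n from the first two operands, copy the input buffer into the output buffer if they differ, factor in place with cgetrf, and hand back the LU matrix, the (1-based, per LAPACK convention) pivots, and the info status; the lax_linalg.py side then subtracts 1 from the pivots, as its comment notes. A rough Python-level sanity check that complex64 LU round-trips through the same LAPACK routine, using scipy's getrf wrapper (which, unlike the raw routine, already returns 0-based pivots):

import numpy as onp
import scipy.linalg

rng = onp.random.RandomState(0)
a = (rng.randn(3, 3) + 1j * rng.randn(3, 3)).astype(onp.complex64)

lu, piv, info = scipy.linalg.lapack.cgetrf(a)
assert info == 0  # 0 means the factorization succeeded

l = onp.tril(lu, -1) + onp.eye(3, dtype=onp.complex64)  # unit lower-triangular factor
u = onp.triu(lu)                                        # upper-triangular factor
perm = onp.arange(3)
for i, p in enumerate(piv):  # replay the recorded row swaps in order
    perm[[i, p]] = perm[[p, i]]
assert onp.allclose(a[perm], l @ u, atol=1e-5)  # P*A == L*U in single precision
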
4 changes: 3 additions & 1 deletion tests/lax_numpy_test.py
@@ -520,7 +520,9 @@ def testDiag(self, shape, dtype, k, rng):
        "axis2": axis2, "rng": jtu.rand_default()}
       for dtype in default_dtypes
       for shape in [shape for shape in all_shapes if len(shape) >= 2]
-      for (axis1, axis2) in itertools.combinations(range(len(shape)), 2)
+      for axis1 in range(-len(shape), len(shape))
+      for axis2 in [a for a in range(-len(shape), len(shape))
+                    if a % len(shape) != axis1 % len(shape)]
       for offset in list(range(-4, 4))))
   def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng):
     onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2)
27 changes: 23 additions & 4 deletions tests/linalg_test.py
@@ -44,6 +44,9 @@ def float_types():
   return set(onp.dtype(xla_bridge.canonicalize_dtype(dtype))
              for dtype in [onp.float32, onp.float64])
 
+def complex_types():
+  return {onp.complex64}
+
 
 class NumpyLinalgTest(jtu.JaxTestCase):
 
@@ -68,8 +71,8 @@ def args_maker():
       {"testcase_name":
        "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
        "n": n, "dtype": dtype, "rng": rng}
-      for n in [0, 4, 5, 200]
-      for dtype in float_types()
+      for n in [0, 4, 5, 50]
+      for dtype in float_types() | complex_types()
       for rng in [jtu.rand_default()]))
   @jtu.skip_on_devices("gpu", "tpu")
   def testDet(self, n, dtype, rng):
@@ -81,6 +84,22 @@ def testDet(self, n, dtype, rng):
                             check_dtypes=True, tol=1e-3)
     self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name":
+       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
+       "n": n, "dtype": dtype, "rng": rng}
+      for n in [0, 4, 10, 200]
+      for dtype in float_types() | complex_types()
+      for rng in [jtu.rand_default()]))
+  @jtu.skip_on_devices("gpu", "tpu")
+  def testSlogdet(self, n, dtype, rng):
+    if not hasattr(lapack, "jax_getrf"):
+      self.skipTest("No LU implementation available")
+    args_maker = lambda: [rng((n, n), dtype)]
+
+    self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
+                            check_dtypes=True, tol=1e-3)
+    self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)
 
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_fullmatrices={}".format(
@@ -169,7 +188,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
-     for dtype in float_types()
+     for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
   @jtu.skip_on_devices("gpu", "tpu")
   def testLu(self, shape, dtype, rng):
@@ -188,7 +207,7 @@ def testLu(self, shape, dtype, rng):
      "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
      "n": n, "dtype": dtype, "rng": rng}
      for n in [1, 4, 5, 200]
-     for dtype in float_types()
+     for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
   @jtu.skip_on_devices("gpu", "tpu")
   def testLuFactor(self, n, dtype, rng):
