Skip to content

Commit

Permalink
ENH: Implement most linalg operations for 0x0 matrices
Browse files Browse the repository at this point in the history
Fixes numpy#8212

det: return ones of the right shape
slogdet: return sign=ones, log=zeros of the right shape
pinv, eigvals(h?), eig(h?): return empty array(s?) of the right shape

svd & qr: not implemented, due to complex return value rules
  • Loading branch information
eric-wieser committed Mar 4, 2017
1 parent 96c3e66 commit 6627449
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 29 deletions.
6 changes: 6 additions & 0 deletions doc/release/1.13.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@ np.matrix with booleans elements can now be created using the string syntax
``np.matrix`` failed whenever one attempts to use it with booleans, e.g.,
``np.matrix('True')``. Now, this works as expected.

More ``linalg`` operations now accept empty vectors and matrices
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
All of the following functions in ``np.linalg`` now work when given input
arrays with a 0 in the last two dimensions: `det``, ``slogdet``, ``pinv``,
``eigvals``, ``eigvalsh``, ``eig``, ``eigh``.

Changes
=======

Expand Down
38 changes: 29 additions & 9 deletions numpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
broadcast, atleast_2d, intp, asanyarray, isscalar, object_
broadcast, atleast_2d, intp, asanyarray, isscalar, object_, ones
)
from numpy.core.multiarray import normalize_axis_index
from numpy.lib import triu, asfarray
Expand Down Expand Up @@ -217,9 +217,13 @@ def _assertFinite(*arrays):
if not (isfinite(a).all()):
raise LinAlgError("Array must not contain infs or NaNs")

def _isEmpty2d(arr):
# check size first for efficiency
return arr.size == 0 and product(arr.shape[-2:]) == 0

def _assertNoEmpty2d(*arrays):
for a in arrays:
if a.size == 0 and product(a.shape[-2:]) == 0:
if _isEmpty2d(a):
raise LinAlgError("Arrays cannot be empty")


Expand Down Expand Up @@ -898,11 +902,12 @@ def eigvals(a):
"""
a, wrap = _makearray(a)
_assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
_assertFinite(a)
t, result_t = _commonType(a)
if _isEmpty2d(a):
return empty(a.shape[-1:], dtype=result_t)

extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
Expand Down Expand Up @@ -1002,10 +1007,11 @@ def eigvalsh(a, UPLO='L'):
gufunc = _umath_linalg.eigvalsh_up

a, wrap = _makearray(a)
_assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
if _isEmpty2d(a):
return empty(a.shape[-1:], dtype=result_t)
signature = 'D->d' if isComplexType(t) else 'd->d'
w = gufunc(a, signature=signature, extobj=extobj)
return w.astype(_realType(result_t), copy=False)
Expand Down Expand Up @@ -1139,11 +1145,14 @@ def eig(a):
"""
a, wrap = _makearray(a)
_assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
_assertFinite(a)
t, result_t = _commonType(a)
if _isEmpty2d(a):
w = empty(a.shape[-1:], dtype=result_t)
vt = empty(a.shape, dtype=result_t)
return w, wrap(vt)

extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
Expand Down Expand Up @@ -1280,8 +1289,11 @@ def eigh(a, UPLO='L'):
a, wrap = _makearray(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
_assertNoEmpty2d(a)
t, result_t = _commonType(a)
if _isEmpty2d(a):
w = empty(a.shape[-1:], dtype=result_t)
vt = empty(a.shape, dtype=result_t)
return w, wrap(vt)

extobj = get_linalg_error_extobj(
_raise_linalgerror_eigenvalues_nonconvergence)
Expand Down Expand Up @@ -1660,7 +1672,9 @@ def pinv(a, rcond=1e-15 ):
"""
a, wrap = _makearray(a)
_assertNoEmpty2d(a)
if _isEmpty2d(a):
res = empty(a.shape[:-2] + (a.shape[-1], a.shape[-2]), dtype=a.dtype)
return wrap(res)
a = a.conjugate()
u, s, vt = svd(a, 0)
m = u.shape[0]
Expand Down Expand Up @@ -1751,11 +1765,15 @@ def slogdet(a):
"""
a = asarray(a)
_assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
real_t = _realType(result_t)
if _isEmpty2d(a):
# determinant of empty matrix is 1
sign = ones(a.shape[:-2], dtype=result_t)
logdet = zeros(a.shape[:-2], dtype=real_t)
return sign, logdet
signature = 'D->Dd' if isComplexType(t) else 'd->dd'
sign, logdet = _umath_linalg.slogdet(a, signature=signature)
if isscalar(sign):
Expand Down Expand Up @@ -1816,10 +1834,12 @@ def det(a):
"""
a = asarray(a)
_assertNoEmpty2d(a)
_assertRankAtLeast2(a)
_assertNdSquareness(a)
t, result_t = _commonType(a)
# 0x0 matrices have determinant 1
if _isEmpty2d(a):
return ones(a.shape[:-2], dtype=result_t)
signature = 'D->D' if isComplexType(t) else 'd->d'
r = _umath_linalg.det(a, signature=signature)
if isscalar(r):
Expand Down
25 changes: 5 additions & 20 deletions numpy/linalg/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ def apply_tag(tag, cases):
array([[2. + 1j, 1. + 2j, 1 + 3j], [1 - 2j, 1 - 3j, 1 - 6j]], dtype=cdouble)),
LinalgCase("0x0",
np.empty((0, 0), dtype=double),
np.empty((0, 0), dtype=double),
np.empty((0,), dtype=double),
tags={'size-0'}),
LinalgCase("0x0_matrix",
np.empty((0, 0), dtype=double).view(np.matrix),
np.empty((0, 1), dtype=double).view(np.matrix),
tags={'size-0'}),
LinalgCase("8x8",
np.random.rand(8, 8),
Expand Down Expand Up @@ -549,9 +553,6 @@ class ArraySubclass(np.ndarray):
class TestEigvals(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):

def do(self, a, b, tags):
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.eigvals, a)
return
ev = linalg.eigvals(a)
evalues, evectors = linalg.eig(a)
assert_almost_equal(ev, evalues)
Expand All @@ -569,10 +570,6 @@ def check(dtype):
class TestEig(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):

def do(self, a, b, tags):
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.eig, a)
return

evalues, evectors = linalg.eig(a)
assert_allclose(dot_generalized(a, evectors),
np.asarray(evectors) * np.asarray(evalues)[..., None, :],
Expand Down Expand Up @@ -667,9 +664,6 @@ def test(self):
class TestPinv(LinalgSquareTestCase, LinalgNonsquareTestCase):

def do(self, a, b, tags):
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.pinv, a)
return
a_ginv = linalg.pinv(a)
# `a @ a_ginv == I` does not hold if a is singular
assert_almost_equal(dot(a, a_ginv).dot(a), a, single_decimal=5, double_decimal=11)
Expand All @@ -679,9 +673,6 @@ def do(self, a, b, tags):
class TestDet(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):

def do(self, a, b, tags):
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.det, a)
return
d = linalg.det(a)
(s, ld) = linalg.slogdet(a)
if asarray(a).dtype.type in (single, double):
Expand Down Expand Up @@ -820,9 +811,6 @@ def test_square(self):
class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase):

def do(self, a, b, tags):
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.eigvalsh, a, 'L')
return
# note that eigenvalue arrays returned by eig must be sorted since
# their order isn't guaranteed.
ev = linalg.eigvalsh(a, 'L')
Expand Down Expand Up @@ -873,9 +861,6 @@ def test_UPLO(self):
class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase):

def do(self, a, b, tags):
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.eigh, a)
return
# note that eigenvalue arrays returned by eig must be sorted since
# their order isn't guaranteed.
ev, evc = linalg.eigh(a)
Expand Down

0 comments on commit 6627449

Please sign in to comment.