Skip to content

Commit

Permalink
Merge pull request numpy#10571 from pv/fix-cond
Browse files Browse the repository at this point in the history
BUG: Fix corner-case behavior of cond() and use SVD when possible
  • Loading branch information
charris authored Feb 13, 2018
2 parents 71962fc + ab8603e commit 02355b4
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 31 deletions.
37 changes: 33 additions & 4 deletions numpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
finfo, errstate, geterrobj, longdouble, moveaxis, amin, amax, product, abs,
broadcast, atleast_2d, intp, asanyarray, object_, ones, matmul,
swapaxes, divide, count_nonzero
swapaxes, divide, count_nonzero, ndarray, isnan
)
from numpy.core.multiarray import normalize_axis_index
from numpy.lib import triu, asfarray
Expand Down Expand Up @@ -1538,11 +1538,40 @@ def cond(x, p=None):
"""
x = asarray(x) # in case we have a matrix
if p is None:
if p is None or p == 2 or p == -2:
s = svd(x, compute_uv=False)
return s[..., 0]/s[..., -1]
with errstate(all='ignore'):
if p == -2:
r = s[..., -1] / s[..., 0]
else:
r = s[..., 0] / s[..., -1]
else:
return norm(x, p, axis=(-2, -1)) * norm(inv(x), p, axis=(-2, -1))
# Call inv(x) ignoring errors. The result array will
# contain nans in the entries where inversion failed.
_assertRankAtLeast2(x)
_assertNdSquareness(x)
t, result_t = _commonType(x)
signature = 'D->D' if isComplexType(t) else 'd->d'
with errstate(all='ignore'):
invx = _umath_linalg.inv(x, signature=signature)
r = norm(x, p, axis=(-2, -1)) * norm(invx, p, axis=(-2, -1))
r = r.astype(result_t, copy=False)

# Convert nans to infs unless the original array had nan entries
r = asarray(r)
nan_mask = isnan(r)
if nan_mask.any():
nan_mask &= ~isnan(x).any(axis=(-2, -1))
if r.ndim > 0:
r[nan_mask] = Inf
elif nan_mask:
r[()] = Inf

# Convention is to return scalars instead of 0d arrays
if r.ndim == 0:
r = r[()]

return r


def matrix_rank(M, tol=None, hermitian=False):
Expand Down
121 changes: 94 additions & 27 deletions numpy/linalg/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,45 +671,112 @@ def test_0_size(self):
assert_raises(linalg.LinAlgError, linalg.svd, a)


class TestCondSVD(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
class TestCond(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
# cond(x, p) for p in (None, 2, -2)

def do(self, a, b, tags):
c = asarray(a) # a might be a matrix
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.svd, c, compute_uv=False)
assert_raises(LinAlgError, linalg.cond, c)
return

# +-2 norms
s = linalg.svd(c, compute_uv=False)
assert_almost_equal(
s[..., 0] / s[..., -1], linalg.cond(a),
linalg.cond(a), s[..., 0] / s[..., -1],
single_decimal=5, double_decimal=11)

def test_stacked_arrays_explicitly(self):
A = np.array([[1., 2., 1.], [0, -2., 0], [6., 2., 3.]])
assert_equal(linalg.cond(A), linalg.cond(A[None, ...])[0])


class TestCond2(LinalgSquareTestCase):

def do(self, a, b, tags):
c = asarray(a) # a might be a matrix
if 'size-0' in tags:
assert_raises(LinAlgError, linalg.svd, c, compute_uv=False)
return
s = linalg.svd(c, compute_uv=False)
assert_almost_equal(
s[..., 0] / s[..., -1], linalg.cond(a, 2),
linalg.cond(a, 2), s[..., 0] / s[..., -1],
single_decimal=5, double_decimal=11)
assert_almost_equal(
linalg.cond(a, -2), s[..., -1] / s[..., 0],
single_decimal=5, double_decimal=11)

def test_stacked_arrays_explicitly(self):
A = np.array([[1., 2., 1.], [0, -2., 0], [6., 2., 3.]])
assert_equal(linalg.cond(A, 2), linalg.cond(A[None, ...], 2)[0])


class TestCondInf(object):
# Other norms
cinv = np.linalg.inv(c)
assert_almost_equal(
linalg.cond(a, 1),
abs(c).sum(-2).max(-1) * abs(cinv).sum(-2).max(-1),
single_decimal=5, double_decimal=11)
assert_almost_equal(
linalg.cond(a, -1),
abs(c).sum(-2).min(-1) * abs(cinv).sum(-2).min(-1),
single_decimal=5, double_decimal=11)
assert_almost_equal(
linalg.cond(a, np.inf),
abs(c).sum(-1).max(-1) * abs(cinv).sum(-1).max(-1),
single_decimal=5, double_decimal=11)
assert_almost_equal(
linalg.cond(a, -np.inf),
abs(c).sum(-1).min(-1) * abs(cinv).sum(-1).min(-1),
single_decimal=5, double_decimal=11)
assert_almost_equal(
linalg.cond(a, 'fro'),
np.sqrt((abs(c)**2).sum(-1).sum(-1)
* (abs(cinv)**2).sum(-1).sum(-1)),
single_decimal=5, double_decimal=11)

def test(self):
A = array([[1., 0, 0], [0, -2., 0], [0, 0, 3.]])
assert_almost_equal(linalg.cond(A, inf), 3.)
def test_basic_nonsvd(self):
# Smoketest the non-svd norms
A = array([[1., 0, 1], [0, -2., 0], [0, 0, 3.]])
assert_almost_equal(linalg.cond(A, inf), 4)
assert_almost_equal(linalg.cond(A, -inf), 2/3)
assert_almost_equal(linalg.cond(A, 1), 4)
assert_almost_equal(linalg.cond(A, -1), 0.5)
assert_almost_equal(linalg.cond(A, 'fro'), np.sqrt(265 / 12))

def test_singular(self):
# Singular matrices have infinite condition number for
# positive norms, and negative norms shouldn't raise
# exceptions
As = [np.zeros((2, 2)), np.ones((2, 2))]
p_pos = [None, 1, 2, 'fro']
p_neg = [-1, -2]
for A, p in itertools.product(As, p_pos):
# Inversion may not hit exact infinity, so just check the
# number is large
assert_(linalg.cond(A, p) > 1e15)
for A, p in itertools.product(As, p_neg):
linalg.cond(A, p)

def test_nan(self):
# nans should be passed through, not converted to infs
ps = [None, 1, -1, 2, -2, 'fro']
p_pos = [None, 1, 2, 'fro']

A = np.ones((2, 2))
A[0,1] = np.nan
for p in ps:
c = linalg.cond(A, p)
assert_(isinstance(c, np.float_))
assert_(np.isnan(c))

A = np.ones((3, 2, 2))
A[1,0,1] = np.nan
for p in ps:
c = linalg.cond(A, p)
assert_(np.isnan(c[1]))
if p in p_pos:
assert_(c[0] > 1e15)
assert_(c[2] > 1e15)
else:
assert_(not np.isnan(c[0]))
assert_(not np.isnan(c[2]))

def test_stacked_singular(self):
# Check behavior when only some of the stacked matrices are
# singular
np.random.seed(1234)
A = np.random.rand(2, 2, 2, 2)
A[0,0] = 0
A[1,1] = 0

for p in (None, 1, 2, 'fro', -1, -2):
c = linalg.cond(A, p)
assert_equal(c[0,0], np.inf)
assert_equal(c[1,1], np.inf)
assert_(np.isfinite(c[0,1]))
assert_(np.isfinite(c[1,0]))


class TestPinv(LinalgSquareTestCase,
Expand Down

0 comments on commit 02355b4

Please sign in to comment.