Skip to content

Commit

Permalink
Merge pull request scipy#4390 from insertinterestingnamehere/norm_key…
Browse files Browse the repository at this point in the history
…words

ENH: Allow axis and keepdims arguments to be passed to scipy.linalg.norm.
  • Loading branch information
rgommers committed Nov 30, 2015
2 parents f13e808 + a8416dd commit 3fd4d65
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
40 changes: 33 additions & 7 deletions scipy/linalg/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
__all__ = ['LinAlgError', 'norm']


def norm(a, ord=None):
def norm(a, ord=None, axis=None, keepdims=False):
"""
Matrix or vector norm.
Expand All @@ -19,15 +19,25 @@ def norm(a, ord=None):
Parameters
----------
a : (M,) or (M, N) array_like
Input array.
Input array. If `axis` is None, `a` must be 1-D or 2-D.
ord : {non-zero int, inf, -inf, 'fro'}, optional
Order of the norm (see table under ``Notes``). inf means numpy's
`inf` object.
`inf` object
axis : {int, 2-tuple of ints, None}, optional
If `axis` is an integer, it specifies the axis of `a` along which to
compute the vector norms. If `axis` is a 2-tuple, it specifies the
axes that hold 2-D matrices, and the matrix norms of these matrices
are computed. If `axis` is None then either a vector norm (when `a`
is 1-D) or a matrix norm (when `a` is 2-D) is returned.
keepdims : bool, optional
If this is set to True, the axes which are normed over are left in the
result as dimensions with size one. With this option the result will
broadcast correctly against the original `a`.
Returns
-------
norm : float
Norm of the matrix or vector.
n : float or ndarray
Norm of the matrix or vector(s).
Notes
-----
Expand Down Expand Up @@ -56,6 +66,10 @@ def norm(a, ord=None):
:math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
The ``axis`` and ``keepdims`` arguments are passed directly to
``numpy.linalg.norm`` and are only usable if they are supported
by the version of numpy in use.
References
----------
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
Expand Down Expand Up @@ -113,16 +127,21 @@ def norm(a, ord=None):
"""
# Differs from numpy only in non-finite handling and the use of blas.
a = np.asarray_chkfinite(a)
if a.dtype.char in 'fdFD':

# Only use optimized norms if axis and keepdims are not specified.
if a.dtype.char in 'fdFD' and axis is None and not keepdims:

if ord in (None, 2) and (a.ndim == 1):
# use blas for fast and stable euclidean norm
nrm2 = get_blas_funcs('nrm2', dtype=a.dtype)
return nrm2(a)

if a.ndim == 2:
if a.ndim == 2 and axis is None and not keepdims:
# Use lapack for a couple fast matrix norms.
# For some reason the *lange frobenius norm is slow.
lange_args = None
# Make sure this works if the user uses the axis keywords
# to apply the norm to the transpose.
if ord == 1:
if np.isfortran(a):
lange_args = '1', a
Expand All @@ -137,6 +156,13 @@ def norm(a, ord=None):
lange = get_lapack_funcs('lange', dtype=a.dtype)
return lange(*lange_args)

# Filter out the axis and keepdims arguments if they aren't used so they
# are never inadvertently passed to a version of numpy that doesn't
# support them.
if axis is not None:
if keepdims:
return np.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims)
return np.linalg.norm(a, ord=ord, axis=axis)
return np.linalg.norm(a, ord=ord)


Expand Down
44 changes: 43 additions & 1 deletion scipy/linalg/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@

from numpy.testing import (TestCase, rand, run_module_suite, assert_raises,
assert_equal, assert_almost_equal, assert_array_almost_equal, assert_,
assert_allclose, assert_array_equal)
assert_allclose, assert_array_equal, dec)

from scipy.linalg import (solve, inv, det, lstsq, pinv, pinv2, pinvh, norm,
solve_banded, solveh_banded, solve_triangular, solve_circulant,
circulant, LinAlgError)

from scipy.linalg._testutils import assert_no_overwrite

from scipy._lib._version import NumpyVersion

REAL_DTYPES = [np.float32, np.float64]
COMPLEX_DTYPES = [np.complex64, np.complex128]
DTYPES = REAL_DTYPES + COMPLEX_DTYPES
Expand Down Expand Up @@ -1072,6 +1074,20 @@ def test_zero_norm(self):
assert_equal(norm([1,0,3], 0), 2)
assert_equal(norm([1,2,3], 0), 3)

@dec.skipif(NumpyVersion(np.__version__) < '1.7.0')
def test_axis_kwd(self):
a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
assert_allclose(norm(a, axis=1), [[3.60555128, 4.12310563]] * 2)
assert_allclose(norm(a, 1, axis=1), [[5.] * 2] * 2)

@dec.skipif(NumpyVersion(np.__version__) < '1.10.0')
def test_keepdims_kwd(self):
a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
b = norm(a, axis=1, keepdims=True)
assert_allclose(b, [[[3.60555128, 4.12310563]]] * 2)
assert_(b.shape == (2, 1, 2))
assert_allclose(norm(a, 1, axis=2, keepdims=True), [[[3.],[7.]]] * 2)


class TestMatrixNorms(object):

Expand All @@ -1095,6 +1111,32 @@ def test_matrix_norms(self):
desired = np.linalg.norm(A.astype(t_high), ord=order)
np.assert_allclose(actual, desired)

@dec.skipif(NumpyVersion(np.__version__) < '1.7.0')
def test_axis_kwd(self):
a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
b = norm(a, ord=np.inf, axis=(1, 0))
c = norm(np.swapaxes(a, 0, 1), ord=np.inf, axis=(0, 1))
d = norm(a, ord=1, axis=(0, 1))
assert_allclose(b, c)
assert_allclose(c, d)
assert_allclose(b, d)
assert_(b.shape == c.shape == d.shape)
b = norm(a, ord=1, axis=(1, 0))
c = norm(np.swapaxes(a, 0, 1), ord=1, axis=(0, 1))
d = norm(a, ord=np.inf, axis=(0, 1))
assert_allclose(b, c)
assert_allclose(c, d)
assert_allclose(b, d)
assert_(b.shape == c.shape == d.shape)

@dec.skipif(NumpyVersion(np.__version__) < '1.10.0')
def test_keepdims_kwd(self):
a = np.arange(120, dtype='d').reshape(2, 3, 4, 5)
b = norm(a, ord=np.inf, axis=(1, 0), keepdims=True)
c = norm(a, ord=1, axis=(0, 1), keepdims=True)
assert_allclose(b, c)
assert_(b.shape == c.shape)


class TestOverwrite(object):
def test_solve(self):
Expand Down

0 comments on commit 3fd4d65

Please sign in to comment.