diff --git a/scipy/linalg/misc.py b/scipy/linalg/misc.py index be1f80ab5317..ad55e4795942 100644 --- a/scipy/linalg/misc.py +++ b/scipy/linalg/misc.py @@ -8,7 +8,7 @@ __all__ = ['LinAlgError', 'norm'] -def norm(a, ord=None): +def norm(a, ord=None, axis=None, keepdims=False): """ Matrix or vector norm. @@ -19,15 +19,25 @@ def norm(a, ord=None): Parameters ---------- a : (M,) or (M, N) array_like - Input array. + Input array. If `axis` is None, `a` must be 1-D or 2-D. ord : {non-zero int, inf, -inf, 'fro'}, optional Order of the norm (see table under ``Notes``). inf means numpy's - `inf` object. + `inf` object + axis : {int, 2-tuple of ints, None}, optional + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. Returns ------- - norm : float - Norm of the matrix or vector. + n : float or ndarray + Norm of the matrix or vector(s). Notes ----- @@ -56,6 +66,10 @@ def norm(a, ord=None): :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + The ``axis`` and ``keepdims`` arguments are passed directly to + ``numpy.linalg.norm`` and are only usable if they are supported + by the version of numpy in use. + References ---------- .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, @@ -113,16 +127,21 @@ def norm(a, ord=None): """ # Differs from numpy only in non-finite handling and the use of blas. a = np.asarray_chkfinite(a) - if a.dtype.char in 'fdFD': + + # Only use optimized norms if axis and keepdims are not specified. + if a.dtype.char in 'fdFD' and axis is None and not keepdims: + if ord in (None, 2) and (a.ndim == 1): # use blas for fast and stable euclidean norm nrm2 = get_blas_funcs('nrm2', dtype=a.dtype) return nrm2(a) - if a.ndim == 2: + if a.ndim == 2 and axis is None and not keepdims: # Use lapack for a couple fast matrix norms. # For some reason the *lange frobenius norm is slow. lange_args = None + # Make sure this works if the user uses the axis keywords + # to apply the norm to the transpose. if ord == 1: if np.isfortran(a): lange_args = '1', a @@ -137,6 +156,13 @@ def norm(a, ord=None): lange = get_lapack_funcs('lange', dtype=a.dtype) return lange(*lange_args) + # Filter out the axis and keepdims arguments if they aren't used so they + # are never inadvertently passed to a version of numpy that doesn't + # support them. + if axis is not None: + if keepdims: + return np.linalg.norm(a, ord=ord, axis=axis, keepdims=keepdims) + return np.linalg.norm(a, ord=ord, axis=axis) return np.linalg.norm(a, ord=ord) diff --git a/scipy/linalg/tests/test_basic.py b/scipy/linalg/tests/test_basic.py index 67c59cd90afa..8b4ecf4eb001 100644 --- a/scipy/linalg/tests/test_basic.py +++ b/scipy/linalg/tests/test_basic.py @@ -28,7 +28,7 @@ from numpy.testing import (TestCase, rand, run_module_suite, assert_raises, assert_equal, assert_almost_equal, assert_array_almost_equal, assert_, - assert_allclose, assert_array_equal) + assert_allclose, assert_array_equal, dec) from scipy.linalg import (solve, inv, det, lstsq, pinv, pinv2, pinvh, norm, solve_banded, solveh_banded, solve_triangular, solve_circulant, @@ -36,6 +36,8 @@ from scipy.linalg._testutils import assert_no_overwrite +from scipy._lib._version import NumpyVersion + REAL_DTYPES = [np.float32, np.float64] COMPLEX_DTYPES = [np.complex64, np.complex128] DTYPES = REAL_DTYPES + COMPLEX_DTYPES @@ -1072,6 +1074,20 @@ def test_zero_norm(self): assert_equal(norm([1,0,3], 0), 2) assert_equal(norm([1,2,3], 0), 3) + @dec.skipif(NumpyVersion(np.__version__) < '1.7.0') + def test_axis_kwd(self): + a = np.array([[[2, 1], [3, 4]]] * 2, 'd') + assert_allclose(norm(a, axis=1), [[3.60555128, 4.12310563]] * 2) + assert_allclose(norm(a, 1, axis=1), [[5.] * 2] * 2) + + @dec.skipif(NumpyVersion(np.__version__) < '1.10.0') + def test_keepdims_kwd(self): + a = np.array([[[2, 1], [3, 4]]] * 2, 'd') + b = norm(a, axis=1, keepdims=True) + assert_allclose(b, [[[3.60555128, 4.12310563]]] * 2) + assert_(b.shape == (2, 1, 2)) + assert_allclose(norm(a, 1, axis=2, keepdims=True), [[[3.],[7.]]] * 2) + class TestMatrixNorms(object): @@ -1095,6 +1111,32 @@ def test_matrix_norms(self): desired = np.linalg.norm(A.astype(t_high), ord=order) np.assert_allclose(actual, desired) + @dec.skipif(NumpyVersion(np.__version__) < '1.7.0') + def test_axis_kwd(self): + a = np.array([[[2, 1], [3, 4]]] * 2, 'd') + b = norm(a, ord=np.inf, axis=(1, 0)) + c = norm(np.swapaxes(a, 0, 1), ord=np.inf, axis=(0, 1)) + d = norm(a, ord=1, axis=(0, 1)) + assert_allclose(b, c) + assert_allclose(c, d) + assert_allclose(b, d) + assert_(b.shape == c.shape == d.shape) + b = norm(a, ord=1, axis=(1, 0)) + c = norm(np.swapaxes(a, 0, 1), ord=1, axis=(0, 1)) + d = norm(a, ord=np.inf, axis=(0, 1)) + assert_allclose(b, c) + assert_allclose(c, d) + assert_allclose(b, d) + assert_(b.shape == c.shape == d.shape) + + @dec.skipif(NumpyVersion(np.__version__) < '1.10.0') + def test_keepdims_kwd(self): + a = np.arange(120, dtype='d').reshape(2, 3, 4, 5) + b = norm(a, ord=np.inf, axis=(1, 0), keepdims=True) + c = norm(a, ord=1, axis=(0, 1), keepdims=True) + assert_allclose(b, c) + assert_(b.shape == c.shape) + class TestOverwrite(object): def test_solve(self):