From c12db309aad07677916d8d30a959e89b99cb49ef Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Wed, 18 Jan 2017 17:41:27 -0500 Subject: [PATCH 1/6] ENH: add dctn, idctn to fftpack --- scipy/fftpack/__init__.py | 4 +- scipy/fftpack/realtransforms.py | 110 +++++++++++++++++++- scipy/fftpack/tests/test_real_transforms.py | 61 ++++++++++- 3 files changed, 171 insertions(+), 4 deletions(-) diff --git a/scipy/fftpack/__init__.py b/scipy/fftpack/__init__.py index 51ff7838cfbf..6f7bc61b25cf 100644 --- a/scipy/fftpack/__init__.py +++ b/scipy/fftpack/__init__.py @@ -19,6 +19,8 @@ irfft - Inverse of rfft dct - Discrete cosine transform idct - Inverse discrete cosine transform + dctn - n-dimensional Discrete cosine transform + idctn - n-dimensional Inverse discrete cosine transform dst - Discrete sine transform idst - Inverse discrete sine transform @@ -102,7 +104,7 @@ del k, register_func from .realtransforms import * -__all__.extend(['dct', 'idct', 'dst', 'idst']) +__all__.extend(['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn']) from scipy._lib._testutils import PytestTester test = PytestTester(__name__) diff --git a/scipy/fftpack/realtransforms.py b/scipy/fftpack/realtransforms.py index 0df1883a0175..1443987f75d6 100644 --- a/scipy/fftpack/realtransforms.py +++ b/scipy/fftpack/realtransforms.py @@ -4,7 +4,7 @@ from __future__ import division, print_function, absolute_import -__all__ = ['dct', 'idct', 'dst', 'idst'] +__all__ = ['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn'] import numpy as np from scipy.fftpack import _fftpack @@ -22,6 +22,114 @@ atexit.register(_fftpack.destroy_dst2_cache) +def dctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): + """ + Return multidimensional Discrete Cosine Transform of x along the specified + axes. + + Parameters + ---------- + x : array_like + The input array. + type : {1, 2, 3}, optional + Type of the DCT (see Notes). Default type is 2. + n : int, optional + Length of the transform. If ``n < x.shape[axis]``, `x` is + truncated. If ``n > x.shape[axis]``, `x` is zero-padded. The + default results in ``n = x.shape[axis]``. + axes : tuple or None, optional + Axes along which the DCT is computed; the default is over all axes. + norm : {None, 'ortho'}, optional + Normalization mode (see Notes). Default is None. + overwrite_x : bool, optional + If True, the contents of `x` can be destroyed; the default is False. + + Returns + ------- + y : ndarray of real + The transformed input array. + + See Also + -------- + idctn : Inverse multidimensional DCT + + Notes + ----- + For full details of the DCT types and normalization modes, as well as + references, see `dct`. + + Examples + -------- + >>> from scipy.fftpack import dctn, idctn + >>> y = np.random.randn(16, 16) + >>> np.allclose(y, idctn(dctn(y, norm='ortho'), norm='ortho')) + True + + """ + x = np.asanyarray(x) + if axes is None: + axes = np.arange(x.ndim) + if np.isscalar(axes): + axes = [axes, ] + for ax in axes: + x = dct(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) + return x + + +def idctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): + """ + Return multidimensional Discrete Cosine Transform of x along the specified + axes. + + Parameters + ---------- + x : array_like + The input array. + type : {1, 2, 3}, optional + Type of the DCT (see Notes). Default type is 2. + n : int, optional + Length of the transform. If ``n < x.shape[axis]``, `x` is + truncated. If ``n > x.shape[axis]``, `x` is zero-padded. The + default results in ``n = x.shape[axis]``. + axes : tuple or None, optional + Axes along which the IDCT is computed; the default is over all axes. + norm : {None, 'ortho'}, optional + Normalization mode (see Notes). Default is None. + overwrite_x : bool, optional + If True, the contents of `x` can be destroyed; the default is False. + + Returns + ------- + y : ndarray of real + The transformed input array. + + See Also + -------- + dctn : multidimensional DCT + + Notes + ----- + For full details of the IDCT types and normalization modes, as well as + references, see `idct`. + + Examples + -------- + >>> from scipy.fftpack import dctn, idctn + >>> y = np.random.randn(16, 16) + >>> np.allclose(y, idctn(dctn(y, norm='ortho'), norm='ortho')) + True + """ + x = np.asanyarray(x) + if axes is None: + axes = np.arange(x.ndim) + if np.isscalar(axes): + axes = [axes, ] + for ax in axes: + x = idct(x, type=type, n=n, axis=ax, norm=norm, + overwrite_x=overwrite_x) + return x + + def dct(x, type=2, n=None, axis=-1, norm=None, overwrite_x=False): """ Return the Discrete Cosine Transform of arbitrary type sequence x. diff --git a/scipy/fftpack/tests/test_real_transforms.py b/scipy/fftpack/tests/test_real_transforms.py index 8fb8b97feb95..9f371a929e56 100644 --- a/scipy/fftpack/tests/test_real_transforms.py +++ b/scipy/fftpack/tests/test_real_transforms.py @@ -5,7 +5,7 @@ import numpy as np from numpy.testing import assert_array_almost_equal, assert_equal -from scipy.fftpack.realtransforms import dct, idct, dst, idst +from scipy.fftpack.realtransforms import dct, idct, dst, idst, dctn, idctn # Matlab reference data MDATA = np.load(join(dirname(__file__), 'test.npz')) @@ -48,7 +48,27 @@ def fftw_dst_ref(type, size, dt): return x, y, dt -class TestComplex(object): +def dct_2d_ref(x, **kwargs): + """ used as a reference in testing dct2. """ + x = np.array(x, copy=True) + for row in range(x.shape[0]): + x[row, :] = dct(x[row, :], **kwargs) + for col in range(x.shape[1]): + x[:, col] = dct(x[:, col], **kwargs) + return x + + +def idct_2d_ref(x, **kwargs): + """ used as a reference in testing idct2. """ + x = np.array(x, copy=True) + for row in range(x.shape[0]): + x[row, :] = idct(x[row, :], **kwargs) + for col in range(x.shape[1]): + x[:, col] = idct(x[:, col], **kwargs) + return x + + +class TestComplex(TestCase): def test_dct_complex64(self): y = dct(1j*np.arange(5, dtype=np.complex64)) x = 1j*dct(np.arange(5)) @@ -521,3 +541,40 @@ def test_idst(self): self._check_1d(idst, dtype, (16, 2), 0, overwritable) self._check_1d(idst, dtype, (2, 16), 1, overwritable) + +class Test_DCTN_IDCTN(TestCase): + dec = 14 + types = [1, 2, 3] + norms = [None, 'ortho'] + rstate = np.random.RandomState(1234) + shape = (32, 16) + data = rstate.randn(*shape) + + def test_axes_round_trip(self): + norm = 'ortho' + for axes in [None, (1, ), (0, ), (0, 1), (-2, -1)]: + for dct_type in self.types: + if norm == 'ortho' and dct_type == 1: + continue # 'ortho' not supported by DCT-I + tmp = dctn(self.data, type=dct_type, axes=axes, norm=norm) + tmp = idctn(tmp, type=dct_type, axes=axes, norm=norm) + assert_array_almost_equal(self.data, tmp, decimal=self.dec) + + def test_dctn_vs_2d_reference(self): + for dct_type in self.types: + for norm in self.norms: + if norm == 'ortho' and dct_type == 1: + continue # 'ortho' not supported by DCT-I + y1 = dctn(self.data, type=dct_type, axes=None, norm=norm) + y2 = dct_2d_ref(self.data, type=dct_type, norm=norm) + assert_array_almost_equal(y1, y2, decimal=11) + + def test_idctn_vs_2d_reference(self): + for dct_type in self.types: + for norm in self.norms: + if norm == 'ortho' and dct_type == 1: + continue # 'ortho' not supported by DCT-I + fdata = dctn(self.data, type=dct_type, norm=norm) + y1 = idctn(fdata, type=dct_type, norm=norm) + y2 = idct_2d_ref(fdata, type=dct_type, norm=norm) + assert_array_almost_equal(y1, y2, decimal=11) From e9fa6f0b719bdb49c0d41d4d71affc5912bd4570 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Mon, 21 Aug 2017 13:13:48 -0400 Subject: [PATCH 2/6] dctn, idctn: raise error on non-unique axes --- scipy/fftpack/realtransforms.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scipy/fftpack/realtransforms.py b/scipy/fftpack/realtransforms.py index 1443987f75d6..8684da60863d 100644 --- a/scipy/fftpack/realtransforms.py +++ b/scipy/fftpack/realtransforms.py @@ -71,6 +71,8 @@ def dctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): axes = np.arange(x.ndim) if np.isscalar(axes): axes = [axes, ] + if len(np.unique(axes)) != len(axes): + raise ValueError("All axes must be unique.") for ax in axes: x = dct(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) return x @@ -124,6 +126,8 @@ def idctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): axes = np.arange(x.ndim) if np.isscalar(axes): axes = [axes, ] + if len(np.unique(axes)) != len(axes): + raise ValueError("All axes must be unique.") for ax in axes: x = idct(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) From e3fdcd6df9a279c11c5ccb9f6a73f78233de27d5 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Mon, 21 Aug 2017 13:28:04 -0400 Subject: [PATCH 3/6] dctn, idctn: replace argument 'n' by 'shape' for consistency with fftn and ifftn. duplicate axes result in an error. --- scipy/fftpack/realtransforms.py | 70 +++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/scipy/fftpack/realtransforms.py b/scipy/fftpack/realtransforms.py index 8684da60863d..f0b4f1b6a239 100644 --- a/scipy/fftpack/realtransforms.py +++ b/scipy/fftpack/realtransforms.py @@ -22,7 +22,7 @@ atexit.register(_fftpack.destroy_dst2_cache) -def dctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): +def dctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ Return multidimensional Discrete Cosine Transform of x along the specified axes. @@ -33,10 +33,13 @@ def dctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): The input array. type : {1, 2, 3}, optional Type of the DCT (see Notes). Default type is 2. - n : int, optional - Length of the transform. If ``n < x.shape[axis]``, `x` is - truncated. If ``n > x.shape[axis]``, `x` is zero-padded. The - default results in ``n = x.shape[axis]``. + shape : tuple of ints, optional + The shape of the result. If both `shape` and `axes` (see below) are + None, `shape` is ``x.shape``; if `shape` is None but `axes` is + not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``. + If ``shape[i] > x.shape[i]``, the i-th dimension is padded with zeros. + If ``shape[i] < x.shape[i]``, the i-th dimension is truncated to + length ``shape[i]``. axes : tuple or None, optional Axes along which the DCT is computed; the default is over all axes. norm : {None, 'ortho'}, optional @@ -67,18 +70,34 @@ def dctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): """ x = np.asanyarray(x) + + if shape is None: + if axes is None: + shape = x.shape + else: + shape = np.take(x.shape, axes) + shape = tuple(shape) + for dim in shape: + if dim < 1: + raise ValueError("Invalid number of DCT data points " + "(%s) specified." % (shape,)) + if axes is None: - axes = np.arange(x.ndim) - if np.isscalar(axes): + axes = list(range(-x.ndim, 0)) + elif np.isscalar(axes): axes = [axes, ] + if len(axes) != len(shape): + raise ValueError("when given, axes and shape arguments " + "have to be of the same length") if len(np.unique(axes)) != len(axes): raise ValueError("All axes must be unique.") - for ax in axes: + + for n, ax in zip(shape, axes): x = dct(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) return x -def idctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): +def idctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ Return multidimensional Discrete Cosine Transform of x along the specified axes. @@ -89,10 +108,13 @@ def idctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): The input array. type : {1, 2, 3}, optional Type of the DCT (see Notes). Default type is 2. - n : int, optional - Length of the transform. If ``n < x.shape[axis]``, `x` is - truncated. If ``n > x.shape[axis]``, `x` is zero-padded. The - default results in ``n = x.shape[axis]``. + shape : tuple of ints, optional + The shape of the result. If both `shape` and `axes` (see below) are + None, `shape` is ``x.shape``; if `shape` is None but `axes` is + not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``. + If ``shape[i] > x.shape[i]``, the i-th dimension is padded with zeros. + If ``shape[i] < x.shape[i]``, the i-th dimension is truncated to + length ``shape[i]``. axes : tuple or None, optional Axes along which the IDCT is computed; the default is over all axes. norm : {None, 'ortho'}, optional @@ -122,13 +144,29 @@ def idctn(x, type=2, n=None, axes=None, norm=None, overwrite_x=False): True """ x = np.asanyarray(x) + + if shape is None: + if axes is None: + shape = x.shape + else: + shape = np.take(x.shape, axes) + shape = tuple(shape) + for dim in shape: + if dim < 1: + raise ValueError("Invalid number of DCT data points " + "(%s) specified." % (shape,)) + if axes is None: - axes = np.arange(x.ndim) - if np.isscalar(axes): + axes = list(range(-x.ndim, 0)) + elif np.isscalar(axes): axes = [axes, ] + if len(axes) != len(shape): + raise ValueError("when given, axes and shape arguments " + "have to be of the same length") if len(np.unique(axes)) != len(axes): raise ValueError("All axes must be unique.") - for ax in axes: + + for n, ax in zip(shape, axes): x = idct(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) return x From 0e07113bcc3aaa712c2d15a69868605a77ecb0cc Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Mon, 21 Aug 2017 13:31:24 -0400 Subject: [PATCH 4/6] move shape and axes checks into a separate helper function --- scipy/fftpack/realtransforms.py | 72 +++++++++++++-------------------- 1 file changed, 28 insertions(+), 44 deletions(-) diff --git a/scipy/fftpack/realtransforms.py b/scipy/fftpack/realtransforms.py index f0b4f1b6a239..731d7a4c3f89 100644 --- a/scipy/fftpack/realtransforms.py +++ b/scipy/fftpack/realtransforms.py @@ -22,6 +22,32 @@ atexit.register(_fftpack.destroy_dst2_cache) +def _init_nd_shape_and_axes(x, shape, axes): + """Handle shape and axes arguments for dctn, idctn, dstn, idstn.""" + if shape is None: + if axes is None: + shape = x.shape + else: + shape = np.take(x.shape, axes) + shape = tuple(shape) + for dim in shape: + if dim < 1: + raise ValueError("Invalid number of DCT data points " + "(%s) specified." % (shape,)) + + if axes is None: + axes = list(range(-x.ndim, 0)) + elif np.isscalar(axes): + axes = [axes, ] + if len(axes) != len(shape): + raise ValueError("when given, axes and shape arguments " + "have to be of the same length") + if len(np.unique(axes)) != len(axes): + raise ValueError("All axes must be unique.") + + return shape, axes + + def dctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ Return multidimensional Discrete Cosine Transform of x along the specified @@ -70,28 +96,7 @@ def dctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ x = np.asanyarray(x) - - if shape is None: - if axes is None: - shape = x.shape - else: - shape = np.take(x.shape, axes) - shape = tuple(shape) - for dim in shape: - if dim < 1: - raise ValueError("Invalid number of DCT data points " - "(%s) specified." % (shape,)) - - if axes is None: - axes = list(range(-x.ndim, 0)) - elif np.isscalar(axes): - axes = [axes, ] - if len(axes) != len(shape): - raise ValueError("when given, axes and shape arguments " - "have to be of the same length") - if len(np.unique(axes)) != len(axes): - raise ValueError("All axes must be unique.") - + shape, axes = _init_nd_shape_and_axes(x, shape, axes) for n, ax in zip(shape, axes): x = dct(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) return x @@ -144,28 +149,7 @@ def idctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): True """ x = np.asanyarray(x) - - if shape is None: - if axes is None: - shape = x.shape - else: - shape = np.take(x.shape, axes) - shape = tuple(shape) - for dim in shape: - if dim < 1: - raise ValueError("Invalid number of DCT data points " - "(%s) specified." % (shape,)) - - if axes is None: - axes = list(range(-x.ndim, 0)) - elif np.isscalar(axes): - axes = [axes, ] - if len(axes) != len(shape): - raise ValueError("when given, axes and shape arguments " - "have to be of the same length") - if len(np.unique(axes)) != len(axes): - raise ValueError("All axes must be unique.") - + shape, axes = _init_nd_shape_and_axes(x, shape, axes) for n, ax in zip(shape, axes): x = idct(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) From 7d3fb706e0216e5bb778d638762de9744941cf83 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Mon, 21 Aug 2017 13:45:13 -0400 Subject: [PATCH 5/6] add dstn, idstn --- scipy/fftpack/realtransforms.py | 108 ++++++++++++++++++++ scipy/fftpack/tests/test_real_transforms.py | 87 +++++++++++----- 2 files changed, 172 insertions(+), 23 deletions(-) diff --git a/scipy/fftpack/realtransforms.py b/scipy/fftpack/realtransforms.py index 731d7a4c3f89..6707fd6817b3 100644 --- a/scipy/fftpack/realtransforms.py +++ b/scipy/fftpack/realtransforms.py @@ -156,6 +156,114 @@ def idctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): return x +def dstn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): + """ + Return multidimensional Discrete Sine Transform of x along the specified + axes. + + Parameters + ---------- + x : array_like + The input array. + type : {1, 2, 3}, optional + Type of the DCT (see Notes). Default type is 2. + shape : tuple of ints, optional + The shape of the result. If both `shape` and `axes` (see below) are + None, `shape` is ``x.shape``; if `shape` is None but `axes` is + not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``. + If ``shape[i] > x.shape[i]``, the i-th dimension is padded with zeros. + If ``shape[i] < x.shape[i]``, the i-th dimension is truncated to + length ``shape[i]``. + axes : tuple or None, optional + Axes along which the DCT is computed; the default is over all axes. + norm : {None, 'ortho'}, optional + Normalization mode (see Notes). Default is None. + overwrite_x : bool, optional + If True, the contents of `x` can be destroyed; the default is False. + + Returns + ------- + y : ndarray of real + The transformed input array. + + See Also + -------- + idstn : Inverse multidimensional DST + + Notes + ----- + For full details of the DST types and normalization modes, as well as + references, see `dst`. + + Examples + -------- + >>> from scipy.fftpack import dstn, idstn + >>> y = np.random.randn(16, 16) + >>> np.allclose(y, idstn(dstn(y, norm='ortho'), norm='ortho')) + True + + """ + x = np.asanyarray(x) + shape, axes = _init_nd_shape_and_axes(x, shape, axes) + for n, ax in zip(shape, axes): + x = dst(x, type=type, n=n, axis=ax, norm=norm, overwrite_x=overwrite_x) + return x + + +def idstn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): + """ + Return multidimensional Discrete Sine Transform of x along the specified + axes. + + Parameters + ---------- + x : array_like + The input array. + type : {1, 2, 3}, optional + Type of the DCT (see Notes). Default type is 2. + shape : tuple of ints, optional + The shape of the result. If both `shape` and `axes` (see below) are + None, `shape` is ``x.shape``; if `shape` is None but `axes` is + not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``. + If ``shape[i] > x.shape[i]``, the i-th dimension is padded with zeros. + If ``shape[i] < x.shape[i]``, the i-th dimension is truncated to + length ``shape[i]``. + axes : tuple or None, optional + Axes along which the IDCT is computed; the default is over all axes. + norm : {None, 'ortho'}, optional + Normalization mode (see Notes). Default is None. + overwrite_x : bool, optional + If True, the contents of `x` can be destroyed; the default is False. + + Returns + ------- + y : ndarray of real + The transformed input array. + + See Also + -------- + dctn : multidimensional DST + + Notes + ----- + For full details of the IDST types and normalization modes, as well as + references, see `idst`. + + Examples + -------- + >>> from scipy.fftpack import dstn, idstn + >>> y = np.random.randn(16, 16) + >>> np.allclose(y, idstn(dctn(y, norm='ortho'), norm='ortho')) + True + """ + x = np.asanyarray(x) + shape, axes = _init_nd_shape_and_axes(x, shape, axes) + for n, ax in zip(shape, axes): + x = idst(x, type=type, n=n, axis=ax, norm=norm, + overwrite_x=overwrite_x) + return x + + def dct(x, type=2, n=None, axis=-1, norm=None, overwrite_x=False): """ Return the Discrete Cosine Transform of arbitrary type sequence x. diff --git a/scipy/fftpack/tests/test_real_transforms.py b/scipy/fftpack/tests/test_real_transforms.py index 9f371a929e56..9911b697c61e 100644 --- a/scipy/fftpack/tests/test_real_transforms.py +++ b/scipy/fftpack/tests/test_real_transforms.py @@ -5,7 +5,8 @@ import numpy as np from numpy.testing import assert_array_almost_equal, assert_equal -from scipy.fftpack.realtransforms import dct, idct, dst, idst, dctn, idctn +from scipy.fftpack.realtransforms import ( + dct, idct, dst, idst, dctn, idctn, dstn, idstn) # Matlab reference data MDATA = np.load(join(dirname(__file__), 'test.npz')) @@ -68,6 +69,26 @@ def idct_2d_ref(x, **kwargs): return x +def dst_2d_ref(x, **kwargs): + """ used as a reference in testing dst2. """ + x = np.array(x, copy=True) + for row in range(x.shape[0]): + x[row, :] = dst(x[row, :], **kwargs) + for col in range(x.shape[1]): + x[:, col] = dst(x[:, col], **kwargs) + return x + + +def idst_2d_ref(x, **kwargs): + """ used as a reference in testing idst2. """ + x = np.array(x, copy=True) + for row in range(x.shape[0]): + x[row, :] = idst(x[row, :], **kwargs) + for col in range(x.shape[1]): + x[:, col] = idst(x[:, col], **kwargs) + return x + + class TestComplex(TestCase): def test_dct_complex64(self): y = dct(1j*np.arange(5, dtype=np.complex64)) @@ -549,32 +570,52 @@ class Test_DCTN_IDCTN(TestCase): rstate = np.random.RandomState(1234) shape = (32, 16) data = rstate.randn(*shape) + # Sets of functions to test + function_sets = [dict(forward=dctn, + inverse=idctn, + forward_ref=dct_2d_ref, + inverse_ref=idct_2d_ref), + dict(forward=dstn, + inverse=idstn, + forward_ref=dst_2d_ref, + inverse_ref=idst_2d_ref), ] def test_axes_round_trip(self): norm = 'ortho' - for axes in [None, (1, ), (0, ), (0, 1), (-2, -1)]: - for dct_type in self.types: - if norm == 'ortho' and dct_type == 1: - continue # 'ortho' not supported by DCT-I - tmp = dctn(self.data, type=dct_type, axes=axes, norm=norm) - tmp = idctn(tmp, type=dct_type, axes=axes, norm=norm) - assert_array_almost_equal(self.data, tmp, decimal=self.dec) + for function_set in self.function_sets: + fforward = function_set['forward'] + finverse = function_set['inverse'] + for axes in [None, (1, ), (0, ), (0, 1), (-2, -1)]: + for dct_type in self.types: + if norm == 'ortho' and dct_type == 1: + continue # 'ortho' not supported by DCT-I + tmp = fforward(self.data, type=dct_type, axes=axes, + norm=norm) + tmp = finverse(tmp, type=dct_type, axes=axes, norm=norm) + assert_array_almost_equal(self.data, tmp, decimal=self.dec) def test_dctn_vs_2d_reference(self): - for dct_type in self.types: - for norm in self.norms: - if norm == 'ortho' and dct_type == 1: - continue # 'ortho' not supported by DCT-I - y1 = dctn(self.data, type=dct_type, axes=None, norm=norm) - y2 = dct_2d_ref(self.data, type=dct_type, norm=norm) - assert_array_almost_equal(y1, y2, decimal=11) + for function_set in self.function_sets: + fforward = function_set['forward'] + fforward_ref = function_set['forward_ref'] + for dct_type in self.types: + for norm in self.norms: + if norm == 'ortho' and dct_type == 1: + continue # 'ortho' not supported by DCT-I + y1 = fforward(self.data, type=dct_type, axes=None, + norm=norm) + y2 = fforward_ref(self.data, type=dct_type, norm=norm) + assert_array_almost_equal(y1, y2, decimal=11) def test_idctn_vs_2d_reference(self): - for dct_type in self.types: - for norm in self.norms: - if norm == 'ortho' and dct_type == 1: - continue # 'ortho' not supported by DCT-I - fdata = dctn(self.data, type=dct_type, norm=norm) - y1 = idctn(fdata, type=dct_type, norm=norm) - y2 = idct_2d_ref(fdata, type=dct_type, norm=norm) - assert_array_almost_equal(y1, y2, decimal=11) + for function_set in self.function_sets: + finverse = function_set['inverse'] + finverse_ref = function_set['inverse_ref'] + for dct_type in self.types: + for norm in self.norms: + if norm == 'ortho' and dct_type == 1: + continue # 'ortho' not supported by DCT-I + fdata = dctn(self.data, type=dct_type, norm=norm) + y1 = finverse(fdata, type=dct_type, norm=norm) + y2 = finverse_ref(fdata, type=dct_type, norm=norm) + assert_array_almost_equal(y1, y2, decimal=11) From d0a34f15c7747db8205498abfedec9b03f7a352a Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Mon, 21 Aug 2017 14:43:33 -0400 Subject: [PATCH 6/6] TST: dctn, idctn: add tests for shape and axes handling --- scipy/fftpack/__init__.py | 5 ++- scipy/fftpack/realtransforms.py | 16 ++++----- scipy/fftpack/tests/test_real_transforms.py | 36 +++++++++++++++++++-- 3 files changed, 44 insertions(+), 13 deletions(-) diff --git a/scipy/fftpack/__init__.py b/scipy/fftpack/__init__.py index 6f7bc61b25cf..6aaa27877465 100644 --- a/scipy/fftpack/__init__.py +++ b/scipy/fftpack/__init__.py @@ -23,6 +23,8 @@ idctn - n-dimensional Inverse discrete cosine transform dst - Discrete sine transform idst - Inverse discrete sine transform + dstn - n-dimensional Discrete sine transform + idstn - n-dimensional Inverse discrete sine transform Differential and pseudo-differential operators ============================================== @@ -104,7 +106,8 @@ del k, register_func from .realtransforms import * -__all__.extend(['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn']) +__all__.extend(['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', + 'idstn']) from scipy._lib._testutils import PytestTester test = PytestTester(__name__) diff --git a/scipy/fftpack/realtransforms.py b/scipy/fftpack/realtransforms.py index 6707fd6817b3..32f763e08825 100644 --- a/scipy/fftpack/realtransforms.py +++ b/scipy/fftpack/realtransforms.py @@ -4,7 +4,7 @@ from __future__ import division, print_function, absolute_import -__all__ = ['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn'] +__all__ = ['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', 'idstn'] import numpy as np from scipy.fftpack import _fftpack @@ -50,8 +50,7 @@ def _init_nd_shape_and_axes(x, shape, axes): def dctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ - Return multidimensional Discrete Cosine Transform of x along the specified - axes. + Return multidimensional Discrete Cosine Transform along the specified axes. Parameters ---------- @@ -104,8 +103,7 @@ def dctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): def idctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ - Return multidimensional Discrete Cosine Transform of x along the specified - axes. + Return multidimensional Discrete Cosine Transform along the specified axes. Parameters ---------- @@ -158,8 +156,7 @@ def idctn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): def dstn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ - Return multidimensional Discrete Sine Transform of x along the specified - axes. + Return multidimensional Discrete Sine Transform along the specified axes. Parameters ---------- @@ -212,8 +209,7 @@ def dstn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): def idstn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): """ - Return multidimensional Discrete Sine Transform of x along the specified - axes. + Return multidimensional Discrete Sine Transform along the specified axes. Parameters ---------- @@ -253,7 +249,7 @@ def idstn(x, type=2, shape=None, axes=None, norm=None, overwrite_x=False): -------- >>> from scipy.fftpack import dstn, idstn >>> y = np.random.randn(16, 16) - >>> np.allclose(y, idstn(dctn(y, norm='ortho'), norm='ortho')) + >>> np.allclose(y, idstn(dstn(y, norm='ortho'), norm='ortho')) True """ x = np.asanyarray(x) diff --git a/scipy/fftpack/tests/test_real_transforms.py b/scipy/fftpack/tests/test_real_transforms.py index 9911b697c61e..05dd06f83286 100644 --- a/scipy/fftpack/tests/test_real_transforms.py +++ b/scipy/fftpack/tests/test_real_transforms.py @@ -4,6 +4,7 @@ import numpy as np from numpy.testing import assert_array_almost_equal, assert_equal +from pytest import raises as assert_raises from scipy.fftpack.realtransforms import ( dct, idct, dst, idst, dctn, idctn, dstn, idstn) @@ -89,7 +90,7 @@ def idst_2d_ref(x, **kwargs): return x -class TestComplex(TestCase): +class TestComplex(object): def test_dct_complex64(self): y = dct(1j*np.arange(5, dtype=np.complex64)) x = 1j*dct(np.arange(5)) @@ -563,7 +564,7 @@ def test_idst(self): self._check_1d(idst, dtype, (2, 16), 1, overwritable) -class Test_DCTN_IDCTN(TestCase): +class Test_DCTN_IDCTN(object): dec = 14 types = [1, 2, 3] norms = [None, 'ortho'] @@ -613,9 +614,40 @@ def test_idctn_vs_2d_reference(self): finverse_ref = function_set['inverse_ref'] for dct_type in self.types: for norm in self.norms: + print(function_set, dct_type, norm) if norm == 'ortho' and dct_type == 1: continue # 'ortho' not supported by DCT-I fdata = dctn(self.data, type=dct_type, norm=norm) y1 = finverse(fdata, type=dct_type, norm=norm) y2 = finverse_ref(fdata, type=dct_type, norm=norm) assert_array_almost_equal(y1, y2, decimal=11) + + def test_axes_and_shape(self): + for function_set in self.function_sets: + fforward = function_set['forward'] + finverse = function_set['inverse'] + + # shape must match the number of axes + assert_raises(ValueError, fforward, self.data, + shape=(self.data.shape[0], ), + axes=(0, 1)) + assert_raises(ValueError, fforward, self.data, + shape=(self.data.shape[0], ), + axes=None) + assert_raises(ValueError, fforward, self.data, + shape=self.data.shape, + axes=(0, )) + # shape must be a tuple + assert_raises(TypeError, fforward, self.data, + shape=self.data.shape[0], + axes=(0, 1)) + + # shape=None works with a subset of axes + for axes in [(0, ), (1, )]: + tmp = fforward(self.data, shape=None, axes=axes, norm='ortho') + tmp = finverse(tmp, shape=None, axes=axes, norm='ortho') + assert_array_almost_equal(self.data, tmp, decimal=self.dec) + + # non-default shape + tmp = fforward(self.data, shape=(128, 128), axes=None) + assert_equal(tmp.shape, (128, 128))