Skip to content

Commit

Permalink
Merge pull request numpy#5468 from jaimefrio/swapaxes_view
Browse files Browse the repository at this point in the history
ENH: Make swapaxes always return a view. Fixes numpy#5260
  • Loading branch information
charris committed Jan 21, 2015
2 parents e73d4fc + a7fdf04 commit 960433e
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 32 deletions.
9 changes: 7 additions & 2 deletions doc/release/1.10.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,17 @@ the case of matrices. Matrices are special cased for backward
compatibility and still return 1-D arrays as before. If you need to
preserve the matrix subtype, use the methods instead of the functions.

*rollaxis* always returns a view
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*rollaxis* and *swapaxes* always return a view
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Previously, a view was returned except when no change was made in the order
of the axes, in which case the input array was returned. A view is now
returned in all cases.

C API
~~~~~
The changes to *swapaxes* also apply to the *PyArray_SwapAxes* C function,
which now returns a view in all cases.


New Features
============
Expand Down
6 changes: 4 additions & 2 deletions numpy/core/fromnumeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,10 @@ def swapaxes(a, axis1, axis2):
Returns
-------
a_swapped : ndarray
If `a` is an ndarray, then a view of `a` is returned; otherwise
a new array is created.
For Numpy >= 1.10, if `a` is an ndarray, then a view of `a` is
returned; otherwise a new array is created. For earlier Numpy
versions a view of `a` is returned only if the order of the
axes is changed, otherwise the input array is returned.
Examples
--------
Expand Down
40 changes: 12 additions & 28 deletions numpy/core/src/multiarray/shape.c
Original file line number Diff line number Diff line change
Expand Up @@ -653,19 +653,8 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2)
{
PyArray_Dims new_axes;
npy_intp dims[NPY_MAXDIMS];
int n, i, val;
PyObject *ret;

if (a1 == a2) {
Py_INCREF(ap);
return (PyObject *)ap;
}

n = PyArray_NDIM(ap);
if (n <= 1) {
Py_INCREF(ap);
return (PyObject *)ap;
}
int n = PyArray_NDIM(ap);
int i;

if (a1 < 0) {
a1 += n;
Expand All @@ -683,25 +672,20 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2)
"bad axis2 argument to swapaxes");
return NULL;
}

for (i = 0; i < n; ++i) {
dims[i] = i;
}
dims[a1] = a2;
dims[a2] = a1;

new_axes.ptr = dims;
new_axes.len = n;

for (i = 0; i < n; i++) {
if (i == a1) {
val = a2;
}
else if (i == a2) {
val = a1;
}
else {
val = i;
}
new_axes.ptr[i] = val;
}
ret = PyArray_Transpose(ap, &new_axes);
return ret;
return PyArray_Transpose(ap, &new_axes);
}


/*NUMPY_API
* Return Transpose.
*/
Expand Down Expand Up @@ -969,7 +953,7 @@ PyArray_Ravel(PyArrayObject *arr, NPY_ORDER order)

PyArray_CreateSortedStridePerm(PyArray_NDIM(arr),
PyArray_STRIDES(arr), strideperm);

for (i = ndim-1; i >= 0; --i) {
if (PyArray_DIM(arr, strideperm[i].perm) == 1) {
/* A size one dimension does not matter */
Expand Down
32 changes: 32 additions & 0 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,38 @@ def test_ravel(self):
assert_equal(a.ravel('A'), [0, 2, 4, 6, 8, 10, 12, 14])
assert_equal(a.ravel('F'), [0, 8, 4, 12, 2, 10, 6, 14])

def test_swapaxes(self):
a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
idx = np.indices(a.shape)
assert_(a.flags['OWNDATA'])
b = a.copy()
# check exceptions
assert_raises(ValueError, a.swapaxes, -5, 0)
assert_raises(ValueError, a.swapaxes, 4, 0)
assert_raises(ValueError, a.swapaxes, 0, -5)
assert_raises(ValueError, a.swapaxes, 0, 4)

for i in range(-4, 4):
for j in range(-4, 4):
for k, src in enumerate((a, b)):
c = src.swapaxes(i, j)
# check shape
shape = list(src.shape)
shape[i] = src.shape[j]
shape[j] = src.shape[i]
assert_equal(c.shape, shape, str((i, j, k)))
# check array contents
i0, i1, i2, i3 = [dim-1 for dim in c.shape]
j0, j1, j2, j3 = [dim-1 for dim in src.shape]
assert_equal(src[idx[j0], idx[j1], idx[j2], idx[j3]],
c[idx[i0], idx[i1], idx[i2], idx[i3]],
str((i, j, k)))
# check a view is always returned, gh-5260
assert_(not c.flags['OWNDATA'], str((i, j, k)))
# check on non-contiguous input array
if k == 1:
b = c

def test_conjugate(self):
a = np.array([1-1j, 1+1j, 23+23.0j])
ac = a.conj()
Expand Down

0 comments on commit 960433e

Please sign in to comment.