From e744b031c20a9ca8fd550f2c51637cdbc8a40307 Mon Sep 17 00:00:00 2001 From: Eric Moore Date: Wed, 14 Jan 2015 09:22:26 -0500 Subject: [PATCH] MAINT/TST: Add test for require, stop making extra copies. Also add ENSUREARRAY and a note about requiring native byteorder to the docs. --- numpy/core/numeric.py | 43 +++++++++++-------- numpy/core/tests/test_numeric.py | 74 ++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 18 deletions(-) diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index c1c55517222e..430f7a7157b7 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -597,7 +597,9 @@ def require(a, dtype=None, requirements=None): a : array_like The object to be converted to a type-and-requirement-satisfying array. dtype : data-type - The required data-type, the default data-type is float64). + The required data-type. If None preserve the current dtype. If your + application requires the data to be in native byteorder, include + a byteorder specification as a part of the dtype specification. requirements : str or list of str The requirements list can be any of the following @@ -606,6 +608,7 @@ def require(a, dtype=None, requirements=None): * 'ALIGNED' ('A') - ensure a data-type aligned array * 'WRITEABLE' ('W') - ensure a writable array * 'OWNDATA' ('O') - ensure an array that owns its own data + * 'ENSUREARRAY', ('E') - ensure a base array, instead of a subclass See Also -------- @@ -642,34 +645,38 @@ def require(a, dtype=None, requirements=None): UPDATEIFCOPY : False """ - if requirements is None: - requirements = [] - else: - requirements = [x.upper() for x in requirements] - + possible_flags = {'C':'C', 'C_CONTIGUOUS':'C', 'CONTIGUOUS':'C', + 'F':'F', 'F_CONTIGUOUS':'F', 'FORTRAN':'F', + 'A':'A', 'ALIGNED':'A', + 'W':'W', 'WRITEABLE':'W', + 'O':'O', 'OWNDATA':'O', + 'E':'E', 'ENSUREARRAY':'E'} if not requirements: return asanyarray(a, dtype=dtype) + else: + requirements = set(possible_flags[x.upper()] for x in requirements) - if 'ENSUREARRAY' in requirements or 'E' in requirements: + if 'E' in requirements: + requirements.remove('E') subok = False else: subok = True - arr = array(a, dtype=dtype, copy=False, subok=subok) + order = 'A' + if requirements >= set(['C', 'F']): + raise ValueError('Cannot specify both "C" and "F" order') + elif 'F' in requirements: + order = 'F' + requirements.remove('F') + elif 'C' in requirements: + order = 'C' + requirements.remove('C') - copychar = 'A' - if 'FORTRAN' in requirements or \ - 'F_CONTIGUOUS' in requirements or \ - 'F' in requirements: - copychar = 'F' - elif 'CONTIGUOUS' in requirements or \ - 'C_CONTIGUOUS' in requirements or \ - 'C' in requirements: - copychar = 'C' + arr = array(a, dtype=dtype, order=order, copy=False, subok=subok) for prop in requirements: if not arr.flags[prop]: - arr = arr.copy(copychar) + arr = arr.copy(order) break return arr diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index d8b01a532389..b151e24f3b36 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2148,5 +2148,79 @@ def test_outer_out_param(): assert_equal(res1, out1) assert_equal(np.outer(arr2, arr3, out2), out2) +class TestRequire(object): + flag_names = ['C', 'C_CONTIGUOUS', 'CONTIGUOUS', + 'F', 'F_CONTIGUOUS', 'FORTRAN', + 'A', 'ALIGNED', + 'W', 'WRITEABLE', + 'O', 'OWNDATA'] + + def generate_all_false(self, dtype): + arr = np.zeros((2, 2), [('junk', 'i1'), ('a', dtype)]) + arr.setflags(write=False) + a = arr['a'] + assert_(not a.flags['C']) + assert_(not a.flags['F']) + assert_(not a.flags['O']) + assert_(not a.flags['W']) + assert_(not a.flags['A']) + return a + + def set_and_check_flag(self, flag, dtype, arr): + if dtype is None: + dtype = arr.dtype + b = np.require(arr, dtype, [flag]) + assert_(b.flags[flag]) + assert_(b.dtype == dtype) + + # a further call to np.require ought to return the same array + # unless OWNDATA is specified. + c = np.require(b, None, [flag]) + if flag[0] != 'O': + assert_(c is b) + else: + assert_(c.flags[flag]) + + def test_require_each(self): + + id = ['f8', 'i4'] + fd = [None, 'f8', 'c16'] + for idtype, fdtype, flag in itertools.product(id, fd, self.flag_names): + a = self.generate_all_false(idtype) + yield self.set_and_check_flag, flag, fdtype, a + + def test_unknown_requirement(self): + a = self.generate_all_false('f8') + assert_raises(KeyError, np.require, a, None, 'Q') + + def test_non_array_input(self): + a = np.require([1, 2, 3, 4], 'i4', ['C', 'A', 'O']) + assert_(a.flags['O']) + assert_(a.flags['C']) + assert_(a.flags['A']) + assert_(a.dtype == 'i4') + assert_equal(a, [1, 2, 3, 4]) + + def test_C_and_F_simul(self): + a = self.generate_all_false('f8') + assert_raises(ValueError, np.require, a, None, ['C', 'F']) + + def test_ensure_array(self): + class ArraySubclass(ndarray): + pass + + a = ArraySubclass((2,2)) + b = np.require(a, None, ['E']) + assert_(type(b) is np.ndarray) + + def test_preserve_subtype(self): + class ArraySubclass(ndarray): + pass + + for flag in self.flag_names: + a = ArraySubclass((2,2)) + yield self.set_and_check_flag, flag, None, a + + if __name__ == "__main__": run_module_suite()