Skip to content

Commit

Permalink
MAINT/TST: Add test for require, stop making extra copies.
Browse files Browse the repository at this point in the history
Also add ENSUREARRAY and a note about  requiring native
byteorder to the docs.
  • Loading branch information
ewmoore authored and charris committed Jan 17, 2015
1 parent 1444550 commit e744b03
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 18 deletions.
43 changes: 25 additions & 18 deletions numpy/core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,9 @@ def require(a, dtype=None, requirements=None):
a : array_like
The object to be converted to a type-and-requirement-satisfying array.
dtype : data-type
The required data-type, the default data-type is float64).
The required data-type. If None preserve the current dtype. If your
application requires the data to be in native byteorder, include
a byteorder specification as a part of the dtype specification.
requirements : str or list of str
The requirements list can be any of the following
Expand All @@ -606,6 +608,7 @@ def require(a, dtype=None, requirements=None):
* 'ALIGNED' ('A') - ensure a data-type aligned array
* 'WRITEABLE' ('W') - ensure a writable array
* 'OWNDATA' ('O') - ensure an array that owns its own data
* 'ENSUREARRAY', ('E') - ensure a base array, instead of a subclass
See Also
--------
Expand Down Expand Up @@ -642,34 +645,38 @@ def require(a, dtype=None, requirements=None):
UPDATEIFCOPY : False
"""
if requirements is None:
requirements = []
else:
requirements = [x.upper() for x in requirements]

possible_flags = {'C':'C', 'C_CONTIGUOUS':'C', 'CONTIGUOUS':'C',
'F':'F', 'F_CONTIGUOUS':'F', 'FORTRAN':'F',
'A':'A', 'ALIGNED':'A',
'W':'W', 'WRITEABLE':'W',
'O':'O', 'OWNDATA':'O',
'E':'E', 'ENSUREARRAY':'E'}
if not requirements:
return asanyarray(a, dtype=dtype)
else:
requirements = set(possible_flags[x.upper()] for x in requirements)

if 'ENSUREARRAY' in requirements or 'E' in requirements:
if 'E' in requirements:
requirements.remove('E')
subok = False
else:
subok = True

arr = array(a, dtype=dtype, copy=False, subok=subok)
order = 'A'
if requirements >= set(['C', 'F']):
raise ValueError('Cannot specify both "C" and "F" order')
elif 'F' in requirements:
order = 'F'
requirements.remove('F')
elif 'C' in requirements:
order = 'C'
requirements.remove('C')

copychar = 'A'
if 'FORTRAN' in requirements or \
'F_CONTIGUOUS' in requirements or \
'F' in requirements:
copychar = 'F'
elif 'CONTIGUOUS' in requirements or \
'C_CONTIGUOUS' in requirements or \
'C' in requirements:
copychar = 'C'
arr = array(a, dtype=dtype, order=order, copy=False, subok=subok)

for prop in requirements:
if not arr.flags[prop]:
arr = arr.copy(copychar)
arr = arr.copy(order)
break
return arr

Expand Down
74 changes: 74 additions & 0 deletions numpy/core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,5 +2148,79 @@ def test_outer_out_param():
assert_equal(res1, out1)
assert_equal(np.outer(arr2, arr3, out2), out2)

class TestRequire(object):
flag_names = ['C', 'C_CONTIGUOUS', 'CONTIGUOUS',
'F', 'F_CONTIGUOUS', 'FORTRAN',
'A', 'ALIGNED',
'W', 'WRITEABLE',
'O', 'OWNDATA']

def generate_all_false(self, dtype):
arr = np.zeros((2, 2), [('junk', 'i1'), ('a', dtype)])
arr.setflags(write=False)
a = arr['a']
assert_(not a.flags['C'])
assert_(not a.flags['F'])
assert_(not a.flags['O'])
assert_(not a.flags['W'])
assert_(not a.flags['A'])
return a

def set_and_check_flag(self, flag, dtype, arr):
if dtype is None:
dtype = arr.dtype
b = np.require(arr, dtype, [flag])
assert_(b.flags[flag])
assert_(b.dtype == dtype)

# a further call to np.require ought to return the same array
# unless OWNDATA is specified.
c = np.require(b, None, [flag])
if flag[0] != 'O':
assert_(c is b)
else:
assert_(c.flags[flag])

def test_require_each(self):

id = ['f8', 'i4']
fd = [None, 'f8', 'c16']
for idtype, fdtype, flag in itertools.product(id, fd, self.flag_names):
a = self.generate_all_false(idtype)
yield self.set_and_check_flag, flag, fdtype, a

def test_unknown_requirement(self):
a = self.generate_all_false('f8')
assert_raises(KeyError, np.require, a, None, 'Q')

def test_non_array_input(self):
a = np.require([1, 2, 3, 4], 'i4', ['C', 'A', 'O'])
assert_(a.flags['O'])
assert_(a.flags['C'])
assert_(a.flags['A'])
assert_(a.dtype == 'i4')
assert_equal(a, [1, 2, 3, 4])

def test_C_and_F_simul(self):
a = self.generate_all_false('f8')
assert_raises(ValueError, np.require, a, None, ['C', 'F'])

def test_ensure_array(self):
class ArraySubclass(ndarray):
pass

a = ArraySubclass((2,2))
b = np.require(a, None, ['E'])
assert_(type(b) is np.ndarray)

def test_preserve_subtype(self):
class ArraySubclass(ndarray):
pass

for flag in self.flag_names:
a = ArraySubclass((2,2))
yield self.set_and_check_flag, flag, None, a


if __name__ == "__main__":
run_module_suite()

0 comments on commit e744b03

Please sign in to comment.