Skip to content

Commit

Permalink
Merge pull request numpy#6679 from pv/may-share-memory-fix
Browse files Browse the repository at this point in the history
Improve may_share_memory performance + fix non-ndarray inputs
  • Loading branch information
juliantaylor committed Nov 15, 2015
2 parents 8ae543c + f2be3a2 commit eeba2cb
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 59 deletions.
39 changes: 39 additions & 0 deletions numpy/add_newdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3826,6 +3826,45 @@ def luf(lamdaexpr, *args, **kwargs):
""")


add_newdoc('numpy.core.multiarray', 'may_share_memory',
"""
may_share_memory(a, b, max_work=None)
Determine if two arrays might share memory
A return of True does not necessarily mean that the two arrays
share any element. It just means that they *might*.
Only the memory bounds of a and b are checked by default.
Parameters
----------
a, b : ndarray
Input arrays
max_work : int, optional
Effort to spend on solving the overlap problem. See
`shares_memory` for details. Default for ``may_share_memory``
is to do a bounds check.
Returns
-------
out : bool
See Also
--------
shares_memory
Examples
--------
>>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
False
>>> x = np.zeros([3, 4])
>>> np.may_share_memory(x[:,0], x[:,1])
True
""")


add_newdoc('numpy.core.multiarray', 'ndarray', ('newbyteorder',
"""
arr.newbyteorder(new_order='S')
Expand Down
45 changes: 1 addition & 44 deletions numpy/core/function_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import division, absolute_import, print_function

__all__ = ['logspace', 'linspace', 'may_share_memory']
__all__ = ['logspace', 'linspace']

from . import numeric as _nx
from .numeric import result_type, NaN, shares_memory, MAY_SHARE_BOUNDS, TooHardError
Expand Down Expand Up @@ -201,46 +201,3 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None):
if dtype is None:
return _nx.power(base, y)
return _nx.power(base, y).astype(dtype)


def may_share_memory(a, b, max_work=None):
"""Determine if two arrays can share memory
A return of True does not necessarily mean that the two arrays
share any element. It just means that they *might*.
Only the memory bounds of a and b are checked by default.
Parameters
----------
a, b : ndarray
Input arrays
max_work : int, optional
Effort to spend on solving the overlap problem. See
`shares_memory` for details. Default for ``may_share_memory``
is to do a bounds check.
Returns
-------
out : bool
See Also
--------
shares_memory
Examples
--------
>>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
False
>>> x = np.zeros([3, 4])
>>> np.may_share_memory(x[:,0], x[:,1])
True
"""
if max_work is None:
max_work = MAY_SHARE_BOUNDS
try:
return shares_memory(a, b, max_work=max_work)
except (TooHardError, OverflowError):
# Unable to determine, assume yes
return True
4 changes: 3 additions & 1 deletion numpy/core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_',
'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'TooHardError',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
'TooHardError',
]

if sys.version_info[0] < 3:
Expand Down Expand Up @@ -384,6 +385,7 @@ def extend_all(module):
fromfile = multiarray.fromfile
frombuffer = multiarray.frombuffer
shares_memory = multiarray.shares_memory
may_share_memory = multiarray.may_share_memory
if sys.version_info[0] < 3:
newbuffer = multiarray.newbuffer
getbuffer = multiarray.getbuffer
Expand Down
84 changes: 70 additions & 14 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3989,25 +3989,52 @@ test_interrupt(PyObject *NPY_UNUSED(self), PyObject *args)


static PyObject *
array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
array_shares_memory_impl(PyObject *args, PyObject *kwds, Py_ssize_t default_max_work,
int raise_exceptions)
{
PyObject * self_obj = NULL;
PyObject * other_obj = NULL;
PyArrayObject * self = NULL;
PyArrayObject * other = NULL;
PyObject *max_work_obj = NULL;
static char *kwlist[] = {"self", "other", "max_work", NULL};

mem_overlap_t result;
static PyObject *too_hard_cls = NULL;
Py_ssize_t max_work = NPY_MAY_SHARE_EXACT;
Py_ssize_t max_work;
NPY_BEGIN_THREADS_DEF;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&|O", kwlist,
PyArray_Converter, &self,
PyArray_Converter, &other,
&max_work_obj)) {
max_work = default_max_work;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist,
&self_obj, &other_obj, &max_work_obj)) {
return NULL;
}

if (PyArray_Check(self_obj)) {
self = (PyArrayObject*)self_obj;
Py_INCREF(self);
}
else {
/* Use FromAny to enable checking overlap for objects exposing array
interfaces etc. */
self = (PyArrayObject*)PyArray_FromAny(self_obj, NULL, 0, 0, 0, NULL);
if (self == NULL) {
goto fail;
}
}

if (PyArray_Check(other_obj)) {
other = (PyArrayObject*)other_obj;
Py_INCREF(other);
}
else {
other = (PyArrayObject*)PyArray_FromAny(other_obj, NULL, 0, 0, 0, NULL);
if (other == NULL) {
goto fail;
}
}

if (max_work_obj == NULL || max_work_obj == Py_None) {
/* noop */
}
Expand Down Expand Up @@ -4043,17 +4070,29 @@ array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwd
Py_RETURN_TRUE;
}
else if (result == MEM_OVERLAP_OVERFLOW) {
PyErr_SetString(PyExc_OverflowError,
"Integer overflow in computing overlap");
return NULL;
if (raise_exceptions) {
PyErr_SetString(PyExc_OverflowError,
"Integer overflow in computing overlap");
return NULL;
}
else {
/* Don't know, so say yes */
Py_RETURN_TRUE;
}
}
else if (result == MEM_OVERLAP_TOO_HARD) {
npy_cache_import("numpy.core._internal", "TooHardError",
&too_hard_cls);
if (too_hard_cls) {
PyErr_SetString(too_hard_cls, "Exceeded max_work");
if (raise_exceptions) {
npy_cache_import("numpy.core._internal", "TooHardError",
&too_hard_cls);
if (too_hard_cls) {
PyErr_SetString(too_hard_cls, "Exceeded max_work");
}
return NULL;
}
else {
/* Don't know, so say yes */
Py_RETURN_TRUE;
}
return NULL;
}
else {
/* Doesn't happen usually */
Expand All @@ -4069,6 +4108,20 @@ array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwd
}


static PyObject *
array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
{
return array_shares_memory_impl(args, kwds, NPY_MAY_SHARE_EXACT, 1);
}


static PyObject *
array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
{
return array_shares_memory_impl(args, kwds, NPY_MAY_SHARE_BOUNDS, 0);
}


static struct PyMethodDef array_module_methods[] = {
{"_get_ndarray_c_version",
(PyCFunction)array__get_ndarray_c_version,
Expand Down Expand Up @@ -4178,6 +4231,9 @@ static struct PyMethodDef array_module_methods[] = {
{"shares_memory",
(PyCFunction)array_shares_memory,
METH_VARARGS | METH_KEYWORDS, NULL},
{"may_share_memory",
(PyCFunction)array_may_share_memory,
METH_VARARGS | METH_KEYWORDS, NULL},
/* Datetime-related functions */
{"datetime_data",
(PyCFunction)array_datetime_data,
Expand Down
28 changes: 28 additions & 0 deletions numpy/core/tests/test_mem_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,5 +482,33 @@ def test_internal_overlap_fuzz():
no_overlap += 1


def test_non_ndarray_inputs():
# Regression check for gh-5604

class MyArray(object):
def __init__(self, data):
self.data = data

@property
def __array_interface__(self):
return self.data.__array_interface__

class MyArray2(object):
def __init__(self, data):
self.data = data

def __array__(self):
return self.data

for cls in [MyArray, MyArray2]:
x = np.arange(5)

assert_(np.may_share_memory(cls(x[::2]), x[1::2]))
assert_(not np.shares_memory(cls(x[::2]), x[1::2]))

assert_(np.shares_memory(cls(x[1::3]), x[::2]))
assert_(np.may_share_memory(cls(x[1::3]), x[::2]))


if __name__ == "__main__":
run_module_suite()

0 comments on commit eeba2cb

Please sign in to comment.