Skip to content

Commit

Permalink
ENH: reimplement may_share_memory in C to improve its performance
Browse files Browse the repository at this point in the history
  • Loading branch information
pv committed Nov 12, 2015
1 parent 4be9ce7 commit 8efc87e
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 55 deletions.
39 changes: 39 additions & 0 deletions numpy/add_newdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3826,6 +3826,45 @@ def luf(lamdaexpr, *args, **kwargs):
""")


add_newdoc('numpy.core.multiarray', 'may_share_memory',
"""
may_share_memory(a, b, max_work=None)
Determine if two arrays might share memory
A return of True does not necessarily mean that the two arrays
share any element. It just means that they *might*.
Only the memory bounds of a and b are checked by default.
Parameters
----------
a, b : ndarray
Input arrays
max_work : int, optional
Effort to spend on solving the overlap problem. See
`shares_memory` for details. Default for ``may_share_memory``
is to do a bounds check.
Returns
-------
out : bool
See Also
--------
shares_memory
Examples
--------
>>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
False
>>> x = np.zeros([3, 4])
>>> np.may_share_memory(x[:,0], x[:,1])
True
""")


add_newdoc('numpy.core.multiarray', 'ndarray', ('newbyteorder',
"""
arr.newbyteorder(new_order='S')
Expand Down
45 changes: 1 addition & 44 deletions numpy/core/function_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import division, absolute_import, print_function

__all__ = ['logspace', 'linspace', 'may_share_memory']
__all__ = ['logspace', 'linspace']

from . import numeric as _nx
from .numeric import result_type, NaN, shares_memory, MAY_SHARE_BOUNDS, TooHardError
Expand Down Expand Up @@ -201,46 +201,3 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None):
if dtype is None:
return _nx.power(base, y)
return _nx.power(base, y).astype(dtype)


def may_share_memory(a, b, max_work=None):
"""Determine if two arrays can share memory
A return of True does not necessarily mean that the two arrays
share any element. It just means that they *might*.
Only the memory bounds of a and b are checked by default.
Parameters
----------
a, b : ndarray
Input arrays
max_work : int, optional
Effort to spend on solving the overlap problem. See
`shares_memory` for details. Default for ``may_share_memory``
is to do a bounds check.
Returns
-------
out : bool
See Also
--------
shares_memory
Examples
--------
>>> np.may_share_memory(np.array([1,2]), np.array([5,8,9]))
False
>>> x = np.zeros([3, 4])
>>> np.may_share_memory(x[:,0], x[:,1])
True
"""
if max_work is None:
max_work = MAY_SHARE_BOUNDS
try:
return shares_memory(a, b, max_work=max_work)
except (TooHardError, OverflowError):
# Unable to determine, assume yes
return True
4 changes: 3 additions & 1 deletion numpy/core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_',
'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', 'TooHardError',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
'TooHardError',
]

if sys.version_info[0] < 3:
Expand Down Expand Up @@ -384,6 +385,7 @@ def extend_all(module):
fromfile = multiarray.fromfile
frombuffer = multiarray.frombuffer
shares_memory = multiarray.shares_memory
may_share_memory = multiarray.may_share_memory
if sys.version_info[0] < 3:
newbuffer = multiarray.newbuffer
getbuffer = multiarray.getbuffer
Expand Down
52 changes: 42 additions & 10 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3989,7 +3989,8 @@ test_interrupt(PyObject *NPY_UNUSED(self), PyObject *args)


static PyObject *
array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
array_shares_memory_impl(PyObject *args, PyObject *kwds, Py_ssize_t default_max_work,
int raise_exceptions)
{
PyArrayObject * self = NULL;
PyArrayObject * other = NULL;
Expand All @@ -3998,9 +3999,11 @@ array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwd

mem_overlap_t result;
static PyObject *too_hard_cls = NULL;
Py_ssize_t max_work = NPY_MAY_SHARE_EXACT;
Py_ssize_t max_work;
NPY_BEGIN_THREADS_DEF;

max_work = default_max_work;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O&|O", kwlist,
PyArray_Converter, &self,
PyArray_Converter, &other,
Expand Down Expand Up @@ -4043,17 +4046,29 @@ array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwd
Py_RETURN_TRUE;
}
else if (result == MEM_OVERLAP_OVERFLOW) {
PyErr_SetString(PyExc_OverflowError,
"Integer overflow in computing overlap");
return NULL;
if (raise_exceptions) {
PyErr_SetString(PyExc_OverflowError,
"Integer overflow in computing overlap");
return NULL;
}
else {
/* Don't know, so say yes */
Py_RETURN_TRUE;
}
}
else if (result == MEM_OVERLAP_TOO_HARD) {
npy_cache_import("numpy.core._internal", "TooHardError",
&too_hard_cls);
if (too_hard_cls) {
PyErr_SetString(too_hard_cls, "Exceeded max_work");
if (raise_exceptions) {
npy_cache_import("numpy.core._internal", "TooHardError",
&too_hard_cls);
if (too_hard_cls) {
PyErr_SetString(too_hard_cls, "Exceeded max_work");
}
return NULL;
}
else {
/* Don't know, so say yes */
Py_RETURN_TRUE;
}
return NULL;
}
else {
/* Doesn't happen usually */
Expand All @@ -4069,6 +4084,20 @@ array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwd
}


static PyObject *
array_shares_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
{
return array_shares_memory_impl(args, kwds, NPY_MAY_SHARE_EXACT, 1);
}


static PyObject *
array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
{
return array_shares_memory_impl(args, kwds, NPY_MAY_SHARE_BOUNDS, 0);
}


static struct PyMethodDef array_module_methods[] = {
{"_get_ndarray_c_version",
(PyCFunction)array__get_ndarray_c_version,
Expand Down Expand Up @@ -4178,6 +4207,9 @@ static struct PyMethodDef array_module_methods[] = {
{"shares_memory",
(PyCFunction)array_shares_memory,
METH_VARARGS | METH_KEYWORDS, NULL},
{"may_share_memory",
(PyCFunction)array_may_share_memory,
METH_VARARGS | METH_KEYWORDS, NULL},
/* Datetime-related functions */
{"datetime_data",
(PyCFunction)array_datetime_data,
Expand Down

0 comments on commit 8efc87e

Please sign in to comment.