Skip to content

Commit

Permalink
ENH: refactor __array_function__ pure Python implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Dec 19, 2018
1 parent f4ddc2b commit 4541345
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 95 deletions.
6 changes: 3 additions & 3 deletions benchmarks/benchmarks/bench_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def mock_broadcast_to(array, shape, subok=False):


def _concatenate_dispatcher(arrays, axis=None, out=None):
for array in arrays:
yield array
if out is not None:
yield out
arrays = list(arrays)
arrays.append(out)
return arrays


@array_function_dispatch(_concatenate_dispatcher)
Expand Down
16 changes: 8 additions & 8 deletions doc/neps/nep-0018-array-function-protocol.rst
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ Changes within NumPy functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Given a function defining the above behavior, for now call it
``array_function_implementation_or_override``, we now need to call that
``implement_array_function``, we now need to call that
function from within every relevant NumPy function. This is a pervasive change,
but of fairly simple and innocuous code that should complete quickly and
without effect if no arguments implement the ``__array_function__``
Expand All @@ -358,7 +358,7 @@ functions:
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return array_function_implementation_or_override(
return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
return public_api
return decorator
Expand Down Expand Up @@ -395,11 +395,11 @@ It's particularly worth calling out the decorator's use of

In a few cases, it would not make sense to use the ``array_function_dispatch``
decorator directly, but override implementation in terms of
``array_function_implementation_or_override`` should still be straightforward.
``implement_array_function`` should still be straightforward.

- Functions written entirely in C (e.g., ``np.concatenate``) can't use
decorators, but they could still use a C equivalent of
``array_function_implementation_or_override``. If performance is not a
``implement_array_function``. If performance is not a
concern, they could also be easily wrapped with a small Python wrapper.
- ``np.einsum`` does complicated argument parsing to handle two different
function signatures. It would probably be best to avoid the overhead of
Expand Down Expand Up @@ -475,7 +475,7 @@ the difference in speed between the ``ndarray.sum()`` method (1.6 us) and
``numpy.sum()`` function (2.6 us).

Fortunately, we expect significantly less overhead with a C implementation of
``array_function_implementation_or_override``, which is where the bulk of the
``implement_array_function``, which is where the bulk of the
runtime is. This would leave the ``array_function_dispatch`` decorator and
dispatcher function on their own adding about 0.5 microseconds of overhead,
for perhaps ~1 microsecond of overhead in the typical case.
Expand Down Expand Up @@ -503,7 +503,7 @@ already wrap a limited subset of SciPy functionality (e.g.,

If we want to do this, we should expose at least the decorator
``array_function_dispatch()`` and possibly also the lower level
``array_function_implementation_or_override()`` as part of NumPy's public API.
``implement_array_function()`` as part of NumPy's public API.

Non-goals
---------
Expand Down Expand Up @@ -807,7 +807,7 @@ public API.

``types`` is included because we can compute it almost for free as part of
collecting ``__array_function__`` implementations to call in
``array_function_implementation_or_override``. We also think it will be used
``implement_array_function``. We also think it will be used
by many ``__array_function__`` methods, which otherwise would need to extract
this information themselves. It would be equivalently easy to provide single
instances of each type, but providing only types seemed cleaner.
Expand All @@ -823,7 +823,7 @@ There are two other arguments that we think *might* be important to pass to
- Access to the non-dispatched implementation (i.e., before wrapping with
``array_function_dispatch``) in ``ndarray.__array_function__`` would allow
us to drop special case logic for that method from
``array_function_implementation_or_override``.
``implement_array_function``.
- Access to the ``dispatcher`` function passed into
``array_function_dispatch()`` would allow ``__array_function__``
implementations to determine the list of "array-like" arguments in a generic
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _array_function(self, func, types, args, kwargs):
# TODO: rewrite this in C
# Cannot handle items that have __array_function__ other than our own.
for t in types:
if not issubclass(t, mu.ndarray) and hasattr(t, '__array_function__'):
if not issubclass(t, mu.ndarray):
return NotImplemented

# The regular implementation can handle this, so we call it directly.
Expand Down
8 changes: 5 additions & 3 deletions numpy/core/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,11 @@ def concatenate(arrays, axis=None, out=None):
fill_value=999999)
"""
for array in arrays:
yield array
yield out
if out is not None:
# optimize for the typical case where only arrays is provided
arrays = list(arrays)
arrays.append(out)
return arrays


@array_function_from_c_func_and_dispatcher(_multiarray_umath.inner)
Expand Down
108 changes: 69 additions & 39 deletions numpy/core/overrides.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Preliminary implementation of NEP-18
"""Preliminary implementation of NEP-18.
TODO: rewrite this in C for performance.
"""
Expand All @@ -10,64 +10,80 @@
from numpy.compat._inspect import getargspec


_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
_NDARRAY_ONLY = [ndarray]

ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))


def get_overloaded_types_and_args(relevant_args):
def get_implementing_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
Parameters
----------
relevant_args : iterable of array-like
Iterable of array-like arguments to check for __array_function__
methods.
Returns
-------
overloaded_types : collection of types
implementing_types : collection of types
Types of arguments from relevant_args with __array_function__ methods.
overloaded_args : list
implementing_args : list
Arguments from relevant_args on which to call __array_function__
methods, in the order in which they should be called.
"""
# Runtime is O(num_arguments * num_unique_types)
overloaded_types = []
overloaded_args = []
implementing_types = []
implementing_args = []
for arg in relevant_args:
arg_type = type(arg)
# We only collect arguments if they have a unique type, which ensures
# reasonable performance even with a long list of possibly overloaded
# arguments.
if (arg_type not in overloaded_types and
if (arg_type not in implementing_types and
hasattr(arg_type, '__array_function__')):

# Create lists explicitly for the first type (usually the only one
# done) to avoid setting up the iterator for overloaded_args.
if overloaded_types:
overloaded_types.append(arg_type)
# done) to avoid setting up the iterator for implementing_args.
if implementing_types:
implementing_types.append(arg_type)
# By default, insert argument at the end, but if it is
# subclass of another argument, insert it before that argument.
# This ensures "subclasses before superclasses".
index = len(overloaded_args)
for i, old_arg in enumerate(overloaded_args):
index = len(implementing_args)
for i, old_arg in enumerate(implementing_args):
if issubclass(arg_type, type(old_arg)):
index = i
break
overloaded_args.insert(index, arg)
implementing_args.insert(index, arg)
else:
overloaded_types = [arg_type]
overloaded_args = [arg]
implementing_types = [arg_type]
implementing_args = [arg]

return implementing_types, implementing_args


_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__


return overloaded_types, overloaded_args
def any_overrides(relevant_args):
"""Are there any __array_function__ methods that need to be called?"""
for arg in relevant_args:
arg_type = type(arg)
if (arg_type is not ndarray and
getattr(arg_type, '__array_function__',
_NDARRAY_ARRAY_FUNCTION)
is not _NDARRAY_ARRAY_FUNCTION):
return True
return False


def array_function_implementation_or_override(
_TUPLE_OR_LIST = {tuple, list}


def implement_array_function(
implementation, public_api, relevant_args, args, kwargs):
"""Implement a function with checks for __array_function__ overrides.
"""
Implement a function with checks for __array_function__ overrides.
All arguments are required, and can only be passed by position.
Arguments
---------
Expand All @@ -82,41 +98,55 @@ def array_function_implementation_or_override(
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
kwargs : tuple
kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
Result from calling `implementation()` or an `__array_function__`
Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
"""
# Check for __array_function__ methods.
types, overloaded_args = get_overloaded_types_and_args(relevant_args)
# Short-cut for common cases: no overload or only ndarray overload
# (directly or with subclasses that do not override __array_function__).
if (not overloaded_args or types == _NDARRAY_ONLY or
all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION
for arg in overloaded_args)):
if type(relevant_args) not in _TUPLE_OR_LIST:
relevant_args = tuple(relevant_args)

if not any_overrides(relevant_args):
return implementation(*args, **kwargs)

# Call overrides
for overloaded_arg in overloaded_args:
types, implementing_args = get_implementing_types_and_args(relevant_args)
for arg in implementing_args:
# Use `public_api` instead of `implemenation` so __array_function__
# implementations can do equality/identity comparisons.
result = overloaded_arg.__array_function__(
public_api, types, args, kwargs)

result = arg.__array_function__(public_api, types, args, kwargs)
if result is not NotImplemented:
return result

func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
raise TypeError("no implementation found for '{}' on types that implement "
'__array_function__: {}'
.format(func_name, list(map(type, overloaded_args))))
'__array_function__: {}'.format(func_name, list(types)))


def _get_implementing_args(relevant_args):
"""
Collect arguments on which to call __array_function__.
Parameters
----------
relevant_args : iterable of array-like
Iterable of possibly array-like arguments to check for
__array_function__ methods.
Returns
-------
Sequence of arguments with __array_function__ methods, in the order in
which they should be called.
"""
_, args = get_implementing_types_and_args(relevant_args)
return args


ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
Expand Down Expand Up @@ -215,7 +245,7 @@ def decorator(implementation):
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return array_function_implementation_or_override(
return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)

if module is not None:
Expand Down
7 changes: 4 additions & 3 deletions numpy/core/shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,11 @@ def hstack(tup):

def _stack_dispatcher(arrays, axis=None, out=None):
arrays = _arrays_for_stack_dispatcher(arrays, stacklevel=6)
for a in arrays:
yield a
if out is not None:
yield out
# optimize for the typical case where only arrays is provided
arrays = list(arrays)
arrays.append(out)
return arrays


@array_function_dispatch(_stack_dispatcher)
Expand Down
Loading

0 comments on commit 4541345

Please sign in to comment.