Skip to content

Commit

Permalink
Merge pull request scipy#15497 from rgommers/update-uarray
Browse files Browse the repository at this point in the history
  • Loading branch information
tupui authored Feb 1, 2022
2 parents d7283ff + ba12a99 commit 9f8f0f9
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 43 deletions.
5 changes: 2 additions & 3 deletions scipy/_lib/_uarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
.. note::
.. note:
If you are looking for overrides for NumPy-specific methods, see the
documentation for :obj:`unumpy`. This page explains how to write
back-ends and multimethods.
Expand Down Expand Up @@ -113,5 +113,4 @@
"""

from ._backend import *

__version__ = '0.8.2+14.gaf53966.scipy'
__version__ = '0.8.8.dev0+aa94c5a4.scipy'
57 changes: 33 additions & 24 deletions scipy/_lib/_uarray/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def get_state():
See Also
--------
set_state : Sets the state returned by this function.
set_state
Sets the state returned by this function.
"""
return _uarray.get_state()

Expand All @@ -122,8 +123,10 @@ def reset_state():
See Also
--------
set_state : Context manager that sets the backend state.
get_state : Gets a state to be set by this context manager.
set_state
Context manager that sets the backend state.
get_state
Gets a state to be set by this context manager.
"""
with set_state(get_state()):
yield
Expand All @@ -136,7 +139,8 @@ def set_state(state):
See Also
--------
get_state : Gets a state to be set by this context manager.
get_state
Gets a state to be set by this context manager.
"""
old_state = get_state()
_uarray.set_state(state)
Expand All @@ -157,7 +161,8 @@ def create_multimethod(*args, **kwargs):
See Also
--------
generate_multimethod : Generates a multimethod.
generate_multimethod
Generates a multimethod.
"""

def wrapper(a):
Expand Down Expand Up @@ -186,7 +191,7 @@ def generate_multimethod(
return an (args, kwargs) pair with the dispatchables replaced inside the args/kwargs.
domain : str
A string value indicating the domain of this multimethod.
default : Optional[Callable], optional
default: Optional[Callable], optional
The default implementation of this multimethod, where ``None`` (the default) specifies
there is no default implementation.
Expand Down Expand Up @@ -225,7 +230,7 @@ def generate_multimethod(
See Also
--------
uarray :
uarray
See the module documentation for how to override the method by creating backends.
"""
kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
Expand Down Expand Up @@ -256,8 +261,8 @@ def set_backend(backend, coerce=False, only=False):
See Also
--------
skip_backend : A context manager that allows skipping of backends.
set_global_backend : Set a single, global backend for a domain.
skip_backend: A context manager that allows skipping of backends.
set_global_backend: Set a single, global backend for a domain.
"""
try:
return backend.__ua_cache__["set", coerce, only]
Expand All @@ -284,8 +289,8 @@ def skip_backend(backend):
See Also
--------
set_backend : A context manager that allows setting of backends.
set_global_backend : Set a single, global backend for a domain.
set_backend: A context manager that allows setting of backends.
set_global_backend: Set a single, global backend for a domain.
"""
try:
return backend.__ua_cache__["skip"]
Expand Down Expand Up @@ -346,8 +351,8 @@ def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
See Also
--------
set_backend : A context manager that allows setting of backends.
skip_backend : A context manager that allows skipping of backends.
set_backend: A context manager that allows setting of backends.
skip_backend: A context manager that allows skipping of backends.
"""
_uarray.set_global_backend(backend, coerce, only, try_last)

Expand Down Expand Up @@ -549,25 +554,27 @@ def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
dispatch_type
The dispatch type associated with ``value``, aka
":ref:`marking <MarkingGlossary>`".
domain : string
domain: string
The domain to query for backends and set.
coerce : bool
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only : bool
only: bool
Whether or not this should be the last backend to try.
See Also
--------
set_backend : For when you know which backend to set
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends not
supporting the type must return ``NotImplemented`` from their
``__ua_convert__`` if they don't support input of that type.
Examples
--------
Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:
Expand Down Expand Up @@ -619,31 +626,33 @@ def determine_backend_multi(
Parameters
----------
dispatchables : Sequence[Union[uarray.Dispatchable, Any]]
dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
The dispatchables that must be supported
domain : string
domain: string
The domain to query for backends and set.
coerce : bool
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only : bool
only: bool
Whether or not this should be the last backend to try.
dispatch_type : Optional[Any]
dispatch_type: Optional[Any]
The default dispatch type associated with ``dispatchables``, aka
":ref:`marking <MarkingGlossary>`".
See Also
--------
determine_backend : For a single dispatch value
set_backend : For when you know which backend to set
determine_backend: For a single dispatch value
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends not
supporting the type must return ``NotImplemented`` from their
``__ua_convert__`` if they don't support input of that type.
Examples
--------
:func:`determine_backend` allows the backend to be set from a single
object. :func:`determine_backend_multi` allows multiple objects to be
checked simultaneously for support in the backend. Suppose we have a
Expand Down
30 changes: 16 additions & 14 deletions scipy/_lib/_uarray/_uarray_dispatch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1283,8 +1283,9 @@ PyObject * Function::call(PyObject * args_, PyObject * kwargs_) {
if (new_args.args == nullptr)
return LoopReturn::Error;

PyObject * args[] = {backend, reinterpret_cast<PyObject *>(this),
new_args.args.get(), new_args.kwargs.get()};
PyObject * args[] = {
backend, reinterpret_cast<PyObject *>(this), new_args.args.get(),
new_args.kwargs.get()};
result = py_ref::steal(Q_PyObject_VectorcallMethod(
identifiers.ua_function.get(), args,
array_size(args) | Q_PY_VECTORCALL_ARGUMENTS_OFFSET, nullptr));
Expand Down Expand Up @@ -1560,9 +1561,9 @@ PyObject * determine_backend(PyObject * /*self*/, PyObject * args) {
return LoopReturn::Continue;
}

PyObject * convert_args[] = {backend, dispatchables_tuple.get(),
(coerce && coerce_backend) ? Py_True
: Py_False};
PyObject * convert_args[] = {
backend, dispatchables_tuple.get(),
(coerce && coerce_backend) ? Py_True : Py_False};

auto res = py_ref::steal(Q_PyObject_VectorcallMethod(
identifiers.ua_convert.get(), convert_args,
Expand Down Expand Up @@ -1761,15 +1762,16 @@ PyMethodDef method_defs[] = {
{NULL} /* Sentinel */
};

PyModuleDef uarray_module = {PyModuleDef_HEAD_INIT,
/* m_name= */ "uarray._uarray",
/* m_doc= */ nullptr,
/* m_size= */ -1,
/* m_methods= */ method_defs,
/* m_slots= */ nullptr,
/* m_traverse= */ globals_traverse,
/* m_clear= */ globals_clear,
/* m_free= */ globals_free};
PyModuleDef uarray_module = {
PyModuleDef_HEAD_INIT,
/* m_name= */ "uarray._uarray",
/* m_doc= */ nullptr,
/* m_size= */ -1,
/* m_methods= */ method_defs,
/* m_slots= */ nullptr,
/* m_traverse= */ globals_traverse,
/* m_clear= */ globals_clear,
/* m_free= */ globals_free};

} // namespace

Expand Down
7 changes: 5 additions & 2 deletions scipy/_lib/_uarray/vectorcall.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ extern "C" {
#endif

// True if python supports vectorcall on custom classes
#define Q_HAS_VECTORCALL \
(!defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x03080000))
#if (!defined(PYPY_VERSION) && (PY_VERSION_HEX >= 0x03080000))
# define Q_HAS_VECTORCALL 1
#else
# define Q_HAS_VECTORCALL 0
#endif

#if !Q_HAS_VECTORCALL
# define Q_Py_TPFLAGS_HAVE_VECTORCALL 0
Expand Down

0 comments on commit 9f8f0f9

Please sign in to comment.