Skip to content

Commit

Permalink
Merge pull request numpy#25145 from mtsokol/array-api-cross
Browse files Browse the repository at this point in the history
API: Add `cross` to `numpy.linalg` [Array API]
  • Loading branch information
ngoldbaum authored Dec 5, 2023
2 parents eabb962 + cd9f69a commit be4e25f
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 22 deletions.
6 changes: 6 additions & 0 deletions doc/release/upcoming_changes/25145.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
``cross`` for `numpy.linalg`
----------------------------

`numpy.linalg.cross` has been added. It computes the cross product of two
(arrays of) 3-dimensional vectors. It differs from `numpy.cross` by accepting
three-dimensional vectors only. This function is compatible with Array API.
1 change: 0 additions & 1 deletion doc/source/reference/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ Function instead of method
These functions are in the ``linalg`` sub-namespace in the array API, but are
only in the top-level namespace in NumPy:

- ``cross``
- ``diagonal``
- ``matmul`` (*)
- ``outer``
Expand Down
1 change: 1 addition & 0 deletions doc/source/reference/routines.linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Matrix and vector products
einsum_path
linalg.matrix_power
kron
linalg.cross

Decompositions
--------------
Expand Down
2 changes: 2 additions & 0 deletions numpy/_core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,8 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
--------
inner : Inner product
outer : Outer product.
linalg.cross : An Array API compatible variation of ``np.cross``,
which accepts (arrays of) 3-element vectors only.
ix_ : Construct index arrays.
Notes
Expand Down
28 changes: 14 additions & 14 deletions numpy/_core/numeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -488,62 +488,62 @@ def moveaxis(

@overload
def cross(
a: _ArrayLikeUnknown,
b: _ArrayLikeUnknown,
x1: _ArrayLikeUnknown,
x2: _ArrayLikeUnknown,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[Any]: ...
@overload
def cross(
a: _ArrayLikeBool_co,
b: _ArrayLikeBool_co,
x1: _ArrayLikeBool_co,
x2: _ArrayLikeBool_co,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NoReturn: ...
@overload
def cross(
a: _ArrayLikeUInt_co,
b: _ArrayLikeUInt_co,
x1: _ArrayLikeUInt_co,
x2: _ArrayLikeUInt_co,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[unsignedinteger[Any]]: ...
@overload
def cross(
a: _ArrayLikeInt_co,
b: _ArrayLikeInt_co,
x1: _ArrayLikeInt_co,
x2: _ArrayLikeInt_co,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[signedinteger[Any]]: ...
@overload
def cross(
a: _ArrayLikeFloat_co,
b: _ArrayLikeFloat_co,
x1: _ArrayLikeFloat_co,
x2: _ArrayLikeFloat_co,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[floating[Any]]: ...
@overload
def cross(
a: _ArrayLikeComplex_co,
b: _ArrayLikeComplex_co,
x1: _ArrayLikeComplex_co,
x2: _ArrayLikeComplex_co,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[complexfloating[Any, Any]]: ...
@overload
def cross(
a: _ArrayLikeObject_co,
b: _ArrayLikeObject_co,
x1: _ArrayLikeObject_co,
x2: _ArrayLikeObject_co,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
Expand Down
1 change: 1 addition & 0 deletions numpy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Matrix and vector products
--------------------------
cross
multi_dot
matrix_power
Expand Down
1 change: 1 addition & 0 deletions numpy/linalg/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from numpy.linalg._linalg import (
multi_dot as multi_dot,
trace as trace,
diagonal as diagonal,
cross as cross,
)

from numpy._pytesttester import PytestTester
Expand Down
61 changes: 57 additions & 4 deletions numpy/linalg/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__all__ = ['matrix_power', 'solve', 'tensorsolve', 'tensorinv', 'inv',
'cholesky', 'eigvals', 'eigvalsh', 'pinv', 'slogdet', 'det',
'svd', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond', 'matrix_rank',
'LinAlgError', 'multi_dot', 'trace', 'diagonal']
'LinAlgError', 'multi_dot', 'trace', 'diagonal', 'cross']

import functools
import operator
Expand All @@ -26,7 +26,8 @@
add, multiply, sqrt, sum, isfinite, finfo, errstate, moveaxis, amin,
amax, prod, abs, atleast_2d, intp, asanyarray, object_, matmul,
swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace,
cross as _core_cross,
)
from numpy.lib._twodim_base_impl import triu, eye
from numpy.lib.array_utils import normalize_axis_index
Expand Down Expand Up @@ -2937,14 +2938,14 @@ def diagonal(x, /, *, offset=0):
See Also
--------
numpy.diagonal
"""
return _core_diagonal(x, offset, axis1=-2, axis2=-1)


# trace

def _trace_dispatcher(
x, /, *, offset=None, dtype=None):
def _trace_dispatcher(x, /, *, offset=None, dtype=None):
return (x,)


Expand Down Expand Up @@ -2990,5 +2991,57 @@ def trace(x, /, *, offset=0, dtype=None):
See Also
--------
numpy.trace
"""
return _core_trace(x, offset, axis1=-2, axis2=-1, dtype=dtype)


# cross

def _cross_dispatcher(x1, x2, /, *, axis=None):
return (x1, x2,)


@array_function_dispatch(_cross_dispatcher)
def cross(x1, x2, /, *, axis=-1):
"""
Returns the cross product of 3-element vectors.
If ``x1`` and/or ``x2`` are multi-dimensional arrays, then
the cross-product of each pair of corresponding 3-element vectors
is independently computed.
This function is Array API compatible, contrary to
:func:`numpy.cross`.
Parameters
----------
x1 : array_like
The first input array.
x2 : array_like
The second input array. Must be compatible with ``x1`` for all
non-compute axes. The size of the axis over which to compute
the cross-product must be the same size as the respective axis
in ``x1``.
axis : int, optional
The axis (dimension) of ``x1`` and ``x2`` containing the vectors for
which to compute the cross-product. Default: ``-1``.
Returns
-------
out : ndarray
An array containing the cross products.
See Also
--------
numpy.cross
"""
if x1.shape[axis] != 3 or x2.shape[axis] != 3:
raise ValueError(
"Both input arrays must be (arrays of) 3-dimensional vectors, "
f"but they are {x1.shape[axis]} and {x2.shape[axis]} "
"dimensional instead."
)

return _core_cross(x1, x2, axis=axis)
28 changes: 28 additions & 0 deletions numpy/linalg/_linalg.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ from numpy import (
generic,
floating,
complexfloating,
signedinteger,
unsignedinteger,
int32,
float64,
complex128,
Expand All @@ -25,6 +27,7 @@ from numpy._typing import (
NDArray,
ArrayLike,
_ArrayLikeInt_co,
_ArrayLikeUInt_co,
_ArrayLikeFloat_co,
_ArrayLikeComplex_co,
_ArrayLikeTD64_co,
Expand Down Expand Up @@ -307,3 +310,28 @@ def trace(
offset: SupportsIndex = ...,
dtype: DTypeLike = ...,
) -> Any: ...

@overload
def cross(
a: _ArrayLikeUInt_co,
b: _ArrayLikeUInt_co,
axis: int = ...,
) -> NDArray[unsignedinteger[Any]]: ...
@overload
def cross(
a: _ArrayLikeInt_co,
b: _ArrayLikeInt_co,
axis: int = ...,
) -> NDArray[signedinteger[Any]]: ...
@overload
def cross(
a: _ArrayLikeFloat_co,
b: _ArrayLikeFloat_co,
axis: int = ...,
) -> NDArray[floating[Any]]: ...
@overload
def cross(
a: _ArrayLikeComplex_co,
b: _ArrayLikeComplex_co,
axis: int = ...,
) -> NDArray[complexfloating[Any, Any]]: ...
20 changes: 20 additions & 0 deletions numpy/linalg/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2228,3 +2228,23 @@ def test_trace():
expected = np.array([36, 116, 196])

assert_equal(actual, expected)


def test_cross():

x = np.arange(9).reshape((3, 3))
actual = np.linalg.cross(x, x + 1)
expected = np.array([
[-1, 2, -1],
[-1, 2, -1],
[-1, 2, -1],
])

assert_equal(actual, expected)

with assert_raises_regex(
ValueError,
r"input arrays must be \(arrays of\) 3-dimensional vectors"
):
x_2dim = x[:, 1:]
np.linalg.cross(x_2dim, x_2dim)
4 changes: 4 additions & 0 deletions numpy/typing/tests/data/reveal/linalg.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,7 @@ assert_type(np.linalg.multi_dot([AR_i8, AR_f8]), Any)
assert_type(np.linalg.multi_dot([AR_f8, AR_c16]), Any)
assert_type(np.linalg.multi_dot([AR_O, AR_O]), Any)
assert_type(np.linalg.multi_dot([AR_m, AR_m]), Any)

assert_type(np.linalg.cross(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]])
assert_type(np.linalg.cross(AR_f8, AR_f8), npt.NDArray[np.floating[Any]])
assert_type(np.linalg.cross(AR_c16, AR_c16), npt.NDArray[np.complexfloating[Any, Any]])
3 changes: 0 additions & 3 deletions tools/ci/array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ array_api_tests/test_data_type_functions.py::test_isdtype
array_api_tests/test_data_type_functions.py::test_astype

# missing names
array_api_tests/test_has_names.py::test_has_names[linalg-cross]
array_api_tests/test_has_names.py::test_has_names[linalg-matmul]
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_norm]
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose]
Expand Down Expand Up @@ -76,7 +75,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]

# missing linalg names
array_api_tests/test_linalg.py::test_cross
array_api_tests/test_linalg.py::test_matrix_norm
array_api_tests/test_linalg.py::test_matrix_transpose
array_api_tests/test_linalg.py::test_outer
Expand Down Expand Up @@ -125,7 +123,6 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
array_api_tests/test_signatures.py::test_func_signature[pow]
array_api_tests/test_signatures.py::test_func_signature[matrix_transpose]
array_api_tests/test_signatures.py::test_func_signature[vecdot]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matmul]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cholesky]
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_norm]
Expand Down

0 comments on commit be4e25f

Please sign in to comment.