Skip to content

Commit

Permalink
Merge pull request numpy#19005 from BvB93/ndarray-meth
Browse files Browse the repository at this point in the history
ENH: Add dtype-support to 11 `ndarray` / `generic` methods
  • Loading branch information
charris authored May 17, 2021
2 parents a62d072 + d71e1e3 commit fdf5a0e
Show file tree
Hide file tree
Showing 6 changed files with 283 additions and 81 deletions.
145 changes: 115 additions & 30 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ from numpy.typing import (
_ComplexLike_co,
_TD64Like_co,
_NumberLike_co,
_ScalarLike_co,

# `number` precision
NBitBase,
Expand Down Expand Up @@ -1239,19 +1240,9 @@ class _ArrayOrScalarCommon:
def copy(self: _ArraySelf, order: _OrderKACF = ...) -> _ArraySelf: ...
def dump(self, file: str) -> None: ...
def dumps(self) -> bytes: ...
def flatten(self, order: _OrderKACF = ...) -> ndarray: ...
def getfield(
self: _ArraySelf, dtype: DTypeLike, offset: int = ...
) -> _ArraySelf: ...
def ravel(self, order: _OrderKACF = ...) -> ndarray: ...
@overload
def reshape(
self, __shape: _ShapeLike, *, order: _OrderACF = ...
) -> ndarray: ...
@overload
def reshape(
self, *shape: SupportsIndex, order: _OrderACF = ...
) -> ndarray: ...
def tobytes(self, order: _OrderKACF = ...) -> bytes: ...
# NOTE: `tostring()` is deprecated and therefore excluded
# def tostring(self, order=...): ...
Expand Down Expand Up @@ -1718,67 +1709,105 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def itemset(self, __value: Any) -> None: ...
@overload
def itemset(self, __item: _ShapeLike, __value: Any) -> None: ...

@overload
def resize(self, __new_shape: _ShapeLike, *, refcheck: bool = ...) -> None: ...
@overload
def resize(self, *new_shape: SupportsIndex, refcheck: bool = ...) -> None: ...

def setflags(
self, write: bool = ..., align: bool = ..., uic: bool = ...
) -> None: ...

def squeeze(
self: _ArraySelf, axis: Union[SupportsIndex, Tuple[SupportsIndex, ...]] = ...
) -> _ArraySelf: ...
def swapaxes(self: _ArraySelf, axis1: SupportsIndex, axis2: SupportsIndex) -> _ArraySelf: ...
self,
axis: Union[SupportsIndex, Tuple[SupportsIndex, ...]] = ...,
) -> ndarray[Any, _DType_co]: ...

def swapaxes(
self,
axis1: SupportsIndex,
axis2: SupportsIndex,
) -> ndarray[Any, _DType_co]: ...

@overload
def transpose(self: _ArraySelf, __axes: _ShapeLike) -> _ArraySelf: ...
@overload
def transpose(self: _ArraySelf, *axes: SupportsIndex) -> _ArraySelf: ...

def argpartition(
self,
kth: _ArrayLikeInt_co,
axis: Optional[SupportsIndex] = ...,
kind: _PartitionKind = ...,
order: Union[None, str, Sequence[str]] = ...,
) -> ndarray: ...
) -> ndarray[Any, dtype[intp]]: ...

def diagonal(
self: _ArraySelf,
self,
offset: SupportsIndex = ...,
axis1: SupportsIndex = ...,
axis2: SupportsIndex = ...,
) -> _ArraySelf: ...
) -> ndarray[Any, _DType_co]: ...

# 1D + 1D returns a scalar;
# all other with at least 1 non-0D array return an ndarray.
@overload
def dot(self, b: _ScalarLike_co, out: None = ...) -> ndarray: ...
@overload
def dot(self, b: ArrayLike, out: None = ...) -> ndarray: ...
def dot(self, b: ArrayLike, out: None = ...) -> Any: ... # type: ignore[misc]
@overload
def dot(self, b: ArrayLike, out: _NdArraySubClass = ...) -> _NdArraySubClass: ...
def dot(self, b: ArrayLike, out: _NdArraySubClass) -> _NdArraySubClass: ...

# `nonzero()` is deprecated for 0d arrays/generics
def nonzero(self) -> Tuple[ndarray, ...]: ...
def nonzero(self) -> Tuple[ndarray[Any, dtype[intp]], ...]: ...

def partition(
self,
kth: _ArrayLikeInt_co,
axis: SupportsIndex = ...,
kind: _PartitionKind = ...,
order: Union[None, str, Sequence[str]] = ...,
) -> None: ...

# `put` is technically available to `generic`,
# but is pointless as `generic`s are immutable
def put(
self, ind: _ArrayLikeInt_co, v: ArrayLike, mode: _ModeKind = ...
self,
ind: _ArrayLikeInt_co,
v: ArrayLike,
mode: _ModeKind = ...,
) -> None: ...

@overload
def searchsorted( # type: ignore[misc]
self, # >= 1D array
v: _ScalarLike_co, # 0D array-like
side: _SortSide = ...,
sorter: Optional[_ArrayLikeInt_co] = ...,
) -> intp: ...
@overload
def searchsorted(
self, # >= 1D array
v: ArrayLike,
side: _SortSide = ...,
sorter: Optional[_ArrayLikeInt_co] = ..., # 1D int array
) -> ndarray: ...
sorter: Optional[_ArrayLikeInt_co] = ...,
) -> ndarray[Any, dtype[intp]]: ...

def setfield(
self, val: ArrayLike, dtype: DTypeLike, offset: SupportsIndex = ...
self,
val: ArrayLike,
dtype: DTypeLike,
offset: SupportsIndex = ...,
) -> None: ...

def sort(
self,
axis: SupportsIndex = ...,
kind: Optional[_SortKind] = ...,
order: Union[None, str, Sequence[str]] = ...,
) -> None: ...

@overload
def trace(
self, # >= 2D array
Expand Down Expand Up @@ -1829,17 +1858,46 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
axis: Optional[SupportsIndex] = ...,
) -> ndarray[Any, _DType_co]: ...

# Many of these special methods are irrelevant currently, since protocols
# aren't supported yet. That said, I'm adding them for completeness.
# https://docs.python.org/3/reference/datamodel.html
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...
def flatten(
self,
order: _OrderKACF = ...,
) -> ndarray[Any, _DType_co]: ...

def ravel(
self,
order: _OrderKACF = ...,
) -> ndarray[Any, _DType_co]: ...

@overload
def reshape(
self, __shape: _ShapeLike, *, order: _OrderACF = ...
) -> ndarray[Any, _DType_co]: ...
@overload
def reshape(
self, *shape: SupportsIndex, order: _OrderACF = ...
) -> ndarray[Any, _DType_co]: ...

# Dispatch to the underlying `generic` via protocols
def __int__(
self: ndarray[Any, dtype[SupportsInt]], # type: ignore[type-var]
) -> int: ...

def __float__(
self: ndarray[Any, dtype[SupportsFloat]], # type: ignore[type-var]
) -> float: ...

def __complex__(
self: ndarray[Any, dtype[SupportsComplex]], # type: ignore[type-var]
) -> complex: ...

def __index__(
self: ndarray[Any, dtype[SupportsIndex]], # type: ignore[type-var]
) -> int: ...

def __len__(self) -> int: ...
def __setitem__(self, key, value): ...
def __iter__(self) -> Any: ...
def __contains__(self, key) -> bool: ...
def __index__(self) -> int: ...

# The last overload is for catching recursive objects whose
# nesting is too deep.
Expand Down Expand Up @@ -2827,6 +2885,25 @@ class generic(_ArrayOrScalarCommon):
axis: Optional[SupportsIndex] = ...,
) -> ndarray[Any, dtype[_ScalarType]]: ...

def flatten(
self: _ScalarType,
order: _OrderKACF = ...,
) -> ndarray[Any, dtype[_ScalarType]]: ...

def ravel(
self: _ScalarType,
order: _OrderKACF = ...,
) -> ndarray[Any, dtype[_ScalarType]]: ...

@overload
def reshape(
self: _ScalarType, __shape: _ShapeLike, *, order: _OrderACF = ...
) -> ndarray[Any, dtype[_ScalarType]]: ...
@overload
def reshape(
self: _ScalarType, *shape: SupportsIndex, order: _OrderACF = ...
) -> ndarray[Any, dtype[_ScalarType]]: ...

def squeeze(
self: _ScalarType, axis: Union[Literal[0], Tuple[()]] = ...
) -> _ScalarType: ...
Expand Down Expand Up @@ -2919,6 +2996,11 @@ class object_(generic):
def real(self: _ArraySelf) -> _ArraySelf: ...
@property
def imag(self: _ArraySelf) -> _ArraySelf: ...
# The 3 protocols below may or may not raise,
# depending on the underlying object
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...

object0 = object_

Expand Down Expand Up @@ -3043,6 +3125,9 @@ class timedelta64(generic):
__value: Union[None, int, _CharLike_co, dt.timedelta, timedelta64] = ...,
__format: Union[_CharLike_co, Tuple[_CharLike_co, _IntLike_co]] = ...,
) -> None: ...

# NOTE: Only a limited number of units support conversion
# to builtin scalar types: `Y`, `M`, `ns`, `ps`, `fs`, `as`
def __int__(self) -> int: ...
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...
Expand Down
9 changes: 9 additions & 0 deletions numpy/typing/tests/data/fail/ndarray_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
"""

from typing import Any
import numpy as np

f8: np.float64
AR_f8: np.ndarray[Any, np.dtype[np.float64]]
AR_M: np.ndarray[Any, np.dtype[np.datetime64]]
AR_b: np.ndarray[Any, np.dtype[np.bool_]]

f8.argpartition(0) # E: has no attribute
f8.diagonal() # E: has no attribute
Expand All @@ -19,3 +23,8 @@
f8.setfield(2, np.float64) # E: has no attribute
f8.sort() # E: has no attribute
f8.trace() # E: has no attribute

AR_M.__int__() # E: Invalid self argument
AR_M.__float__() # E: Invalid self argument
AR_M.__complex__() # E: Invalid self argument
AR_b.__index__() # E: Invalid self argument
21 changes: 21 additions & 0 deletions numpy/typing/tests/data/pass/ndarray_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from __future__ import annotations

import operator
from typing import cast, Any

import numpy as np

class SubClass(np.ndarray): ...
Expand Down Expand Up @@ -162,3 +164,22 @@ class SubClass(np.ndarray): ...

A.item(0)
C.item(0)

A.ravel()
C.ravel()

A.flatten()
C.flatten()

A.reshape(1)
C.reshape(3)

int(np.array(1.0, dtype=np.float64))
int(np.array("1", dtype=np.str_))

float(np.array(1.0, dtype=np.float64))
float(np.array("1", dtype=np.str_))

complex(np.array(1.0, dtype=np.float64))

operator.index(np.array(1, dtype=np.int64))
62 changes: 47 additions & 15 deletions numpy/typing/tests/data/pass/scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
import pytest
import numpy as np

b = np.bool_()
u8 = np.uint64()
i8 = np.int64()
f8 = np.float64()
c16 = np.complex128()
U = np.str_()
S = np.bytes_()


# Construction
class D:
Expand Down Expand Up @@ -205,18 +213,42 @@ def __float__(self) -> float:
np.clongfloat()
np.longcomplex()

np.bool_().item()
np.int_().item()
np.uint64().item()
np.float32().item()
np.complex128().item()
np.str_().item()
np.bytes_().item()

np.bool_().tolist()
np.int_().tolist()
np.uint64().tolist()
np.float32().tolist()
np.complex128().tolist()
np.str_().tolist()
np.bytes_().tolist()
b.item()
i8.item()
u8.item()
f8.item()
c16.item()
U.item()
S.item()

b.tolist()
i8.tolist()
u8.tolist()
f8.tolist()
c16.tolist()
U.tolist()
S.tolist()

b.ravel()
i8.ravel()
u8.ravel()
f8.ravel()
c16.ravel()
U.ravel()
S.ravel()

b.flatten()
i8.flatten()
u8.flatten()
f8.flatten()
c16.flatten()
U.flatten()
S.flatten()

b.reshape(1)
i8.reshape(1)
u8.reshape(1)
f8.reshape(1)
c16.reshape(1)
U.reshape(1)
S.reshape(1)
Loading

0 comments on commit fdf5a0e

Please sign in to comment.