Skip to content

Commit

Permalink
TYP: 1-d shape-typing for ndarray.flatten and ravel
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Nov 13, 2024
1 parent 7bfdc8c commit 547deac
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 53 deletions.
35 changes: 12 additions & 23 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2190,17 +2190,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
axis: None | SupportsIndex = ...,
) -> ndarray[_Shape, _DType_co]: ...

# TODO: use `tuple[int]` as shape type once covariant (#26081)
def flatten(
self,
order: _OrderKACF = ...,
) -> ndarray[_Shape, _DType_co]: ...

# TODO: use `tuple[int]` as shape type once covariant (#26081)
def ravel(
self,
order: _OrderKACF = ...,
) -> ndarray[_Shape, _DType_co]: ...
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DType_co]: ...
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DType_co]: ...

@overload
def reshape(
Expand Down Expand Up @@ -3100,11 +3091,10 @@ _NBit_fc = TypeVar("_NBit_fc", _NBitHalf, _NBitSingle, _NBitDouble, _NBitLongDou
class generic(_ArrayOrScalarCommon):
@abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
# TODO: use `tuple[()]` as shape type once covariant (#26081)
@overload
def __array__(self, dtype: None = ..., /) -> NDArray[Self]: ...
def __array__(self, dtype: None = None, /) -> ndarray[tuple[()], dtype[Self]]: ...
@overload
def __array__(self, dtype: _DType, /) -> ndarray[_Shape, _DType]: ...
def __array__(self, dtype: _DType, /) -> ndarray[tuple[()], _DType]: ...
def __hash__(self) -> int: ...
@property
def base(self) -> None: ...
Expand All @@ -3118,7 +3108,7 @@ class generic(_ArrayOrScalarCommon):
def strides(self) -> tuple[()]: ...
def byteswap(self, inplace: L[False] = ...) -> Self: ...
@property
def flat(self) -> flatiter[NDArray[Self]]: ...
def flat(self) -> flatiter[ndarray[tuple[int], dtype[Self]]]: ...

if sys.version_info >= (3, 12):
def __buffer__(self, flags: int, /) -> memoryview: ...
Expand Down Expand Up @@ -3202,8 +3192,8 @@ class generic(_ArrayOrScalarCommon):
) -> _NdArraySubClass: ...

def repeat(self, repeats: _ArrayLikeInt_co, axis: None | SupportsIndex = ...) -> NDArray[Self]: ...
def flatten(self, order: _OrderKACF = ...) -> NDArray[Self]: ...
def ravel(self, order: _OrderKACF = ...) -> NDArray[Self]: ...
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], dtype[Self]]: ...

@overload
def reshape(self, shape: _ShapeLike, /, *, order: _OrderACF = ...) -> NDArray[Self]: ...
Expand Down Expand Up @@ -4492,13 +4482,12 @@ class poly1d:
@coefficients.setter
def coefficients(self, value: NDArray[Any]) -> None: ...

__hash__: ClassVar[None] # type: ignore
__hash__: ClassVar[None] # type: ignore[assignment] # pyright: ignore[reportIncompatibleMethodOverride]

# TODO: use `tuple[int]` as shape type once covariant (#26081)
@overload
def __array__(self, t: None = ..., copy: None | bool = ...) -> NDArray[Any]: ...
def __array__(self, /, t: None = None, copy: builtins.bool | None = None) -> ndarray[tuple[int], dtype[Any]]: ...
@overload
def __array__(self, t: _DType, copy: None | bool = ...) -> ndarray[_Shape, _DType]: ...
def __array__(self, /, t: _DType, copy: builtins.bool | None = None) -> ndarray[tuple[int], _DType]: ...

@overload
def __call__(self, val: _ScalarLike_co) -> Any: ...
Expand Down Expand Up @@ -4668,8 +4657,8 @@ class matrix(ndarray[_Shape2DType_co, _DType_co]):

def squeeze(self, axis: None | _ShapeLike = ...) -> matrix[_Shape2D, _DType_co]: ...
def tolist(self: _SupportsItem[_T]) -> list[list[_T]]: ...
def ravel(self, order: _OrderKACF = ...) -> matrix[_Shape2D, _DType_co]: ...
def flatten(self, order: _OrderKACF = ...) -> matrix[_Shape2D, _DType_co]: ...
def ravel(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]
def flatten(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]

@property
def T(self) -> matrix[_Shape2D, _DType_co]: ...
Expand Down
22 changes: 20 additions & 2 deletions numpy/_core/fromnumeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ from numpy._typing import (
ArrayLike,
_ArrayLike,
NDArray,
_NestedSequence,
_ShapeLike,
_ArrayLikeBool_co,
_ArrayLikeUInt_co,
Expand Down Expand Up @@ -438,10 +439,27 @@ def trace(
out: _ArrayType = ...,
) -> _ArrayType: ...

_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_SCT]]

@overload
def ravel(a: _ArrayLike[_SCT], order: _OrderKACF = "C") -> _Array1D[_SCT]: ...
@overload
def ravel(a: bytes | _NestedSequence[bytes], order: _OrderKACF = "C") -> _Array1D[np.bytes_]: ...
@overload
def ravel(a: str | _NestedSequence[str], order: _OrderKACF = "C") -> _Array1D[np.str_]: ...
@overload
def ravel(a: bool | _NestedSequence[bool], order: _OrderKACF = "C") -> _Array1D[np.bool]: ...
@overload
def ravel(a: int | _NestedSequence[int], order: _OrderKACF = "C") -> _Array1D[np.int_ | np.bool]: ...
@overload
def ravel(a: float | _NestedSequence[float], order: _OrderKACF = "C") -> _Array1D[np.float64 | np.int_ | np.bool]: ...
@overload
def ravel(a: _ArrayLike[_SCT], order: _OrderKACF = ...) -> NDArray[_SCT]: ...
def ravel(
a: complex | _NestedSequence[complex],
order: _OrderKACF = "C",
) -> _Array1D[np.complex128 | np.float64 | np.int_ | np.bool]: ...
@overload
def ravel(a: ArrayLike, order: _OrderKACF = ...) -> NDArray[Any]: ...
def ravel(a: ArrayLike, order: _OrderKACF = "C") -> np.ndarray[tuple[int], np.dtype[Any]]: ...

@overload
def nonzero(a: np.generic | np.ndarray[tuple[()], Any]) -> NoReturn: ...
Expand Down
10 changes: 5 additions & 5 deletions numpy/typing/tests/data/reveal/fromnumeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ assert_type(np.trace(AR_b), Any)
assert_type(np.trace(AR_f4), Any)
assert_type(np.trace(AR_f4, out=AR_subclass), NDArraySubclass)

assert_type(np.ravel(b), npt.NDArray[np.bool])
assert_type(np.ravel(f4), npt.NDArray[np.float32])
assert_type(np.ravel(f), npt.NDArray[Any])
assert_type(np.ravel(AR_b), npt.NDArray[np.bool])
assert_type(np.ravel(AR_f4), npt.NDArray[np.float32])
assert_type(np.ravel(b), np.ndarray[tuple[int], np.dtype[np.bool]])
assert_type(np.ravel(f4), np.ndarray[tuple[int], np.dtype[np.float32]])
assert_type(np.ravel(f), np.ndarray[tuple[int], np.dtype[np.float64 | np.int_ | np.bool]])
assert_type(np.ravel(AR_b), np.ndarray[tuple[int], np.dtype[np.bool]])
assert_type(np.ravel(AR_f4), np.ndarray[tuple[int], np.dtype[np.float32]])

assert_type(np.nonzero(b), NoReturn)
assert_type(np.nonzero(f4), NoReturn)
Expand Down
8 changes: 4 additions & 4 deletions numpy/typing/tests/data/reveal/ndarray_misc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ assert_type(AR_f8.trace(out=B), SubClass)
assert_type(AR_f8.item(), float)
assert_type(AR_U.item(), str)

assert_type(AR_f8.ravel(), npt.NDArray[np.float64])
assert_type(AR_U.ravel(), npt.NDArray[np.str_])
assert_type(AR_f8.ravel(), np.ndarray[tuple[int], np.dtype[np.float64]])
assert_type(AR_U.ravel(), np.ndarray[tuple[int], np.dtype[np.str_]])

assert_type(AR_f8.flatten(), npt.NDArray[np.float64])
assert_type(AR_U.flatten(), npt.NDArray[np.str_])
assert_type(AR_f8.flatten(), np.ndarray[tuple[int], np.dtype[np.float64]])
assert_type(AR_U.flatten(), np.ndarray[tuple[int], np.dtype[np.str_]])

assert_type(AR_f8.reshape(1), npt.NDArray[np.float64])
assert_type(AR_U.reshape(1), npt.NDArray[np.str_])
Expand Down
8 changes: 4 additions & 4 deletions numpy/typing/tests/data/reveal/ndarray_shape_manipulation.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ assert_type(nd.transpose((1, 0)), npt.NDArray[np.int64])
assert_type(nd.swapaxes(0, 1), npt.NDArray[np.int64])

# flatten
assert_type(nd.flatten(), npt.NDArray[np.int64])
assert_type(nd.flatten("C"), npt.NDArray[np.int64])
assert_type(nd.flatten(), np.ndarray[tuple[int], np.dtype[np.int64]])
assert_type(nd.flatten("C"), np.ndarray[tuple[int], np.dtype[np.int64]])

# ravel
assert_type(nd.ravel(), npt.NDArray[np.int64])
assert_type(nd.ravel("C"), npt.NDArray[np.int64])
assert_type(nd.ravel(), np.ndarray[tuple[int], np.dtype[np.int64]])
assert_type(nd.ravel("C"), np.ndarray[tuple[int], np.dtype[np.int64]])

# squeeze
assert_type(nd.squeeze(), npt.NDArray[np.int64])
Expand Down
30 changes: 15 additions & 15 deletions numpy/typing/tests/data/reveal/scalars.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,21 @@ assert_type(c16.tolist(), complex)
assert_type(U.tolist(), str)
assert_type(S.tolist(), bytes)

assert_type(b.ravel(), npt.NDArray[np.bool])
assert_type(i8.ravel(), npt.NDArray[np.int64])
assert_type(u8.ravel(), npt.NDArray[np.uint64])
assert_type(f8.ravel(), npt.NDArray[np.float64])
assert_type(c16.ravel(), npt.NDArray[np.complex128])
assert_type(U.ravel(), npt.NDArray[np.str_])
assert_type(S.ravel(), npt.NDArray[np.bytes_])

assert_type(b.flatten(), npt.NDArray[np.bool])
assert_type(i8.flatten(), npt.NDArray[np.int64])
assert_type(u8.flatten(), npt.NDArray[np.uint64])
assert_type(f8.flatten(), npt.NDArray[np.float64])
assert_type(c16.flatten(), npt.NDArray[np.complex128])
assert_type(U.flatten(), npt.NDArray[np.str_])
assert_type(S.flatten(), npt.NDArray[np.bytes_])
assert_type(b.ravel(), np.ndarray[tuple[int], np.dtype[np.bool]])
assert_type(i8.ravel(), np.ndarray[tuple[int], np.dtype[np.int64]])
assert_type(u8.ravel(), np.ndarray[tuple[int], np.dtype[np.uint64]])
assert_type(f8.ravel(), np.ndarray[tuple[int], np.dtype[np.float64]])
assert_type(c16.ravel(), np.ndarray[tuple[int], np.dtype[np.complex128]])
assert_type(U.ravel(), np.ndarray[tuple[int], np.dtype[np.str_]])
assert_type(S.ravel(), np.ndarray[tuple[int], np.dtype[np.bytes_]])

assert_type(b.flatten(), np.ndarray[tuple[int], np.dtype[np.bool]])
assert_type(i8.flatten(), np.ndarray[tuple[int], np.dtype[np.int64]])
assert_type(u8.flatten(), np.ndarray[tuple[int], np.dtype[np.uint64]])
assert_type(f8.flatten(), np.ndarray[tuple[int], np.dtype[np.float64]])
assert_type(c16.flatten(), np.ndarray[tuple[int], np.dtype[np.complex128]])
assert_type(U.flatten(), np.ndarray[tuple[int], np.dtype[np.str_]])
assert_type(S.flatten(), np.ndarray[tuple[int], np.dtype[np.bytes_]])

assert_type(b.reshape(1), npt.NDArray[np.bool])
assert_type(i8.reshape(1), npt.NDArray[np.int64])
Expand Down

0 comments on commit 547deac

Please sign in to comment.