Skip to content

Commit

Permalink
TYP: Fix falsely rejected value types in ndarray.__setitem__
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Dec 10, 2024
1 parent 7650730 commit 326f676
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
67 changes: 55 additions & 12 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,8 @@ _Float64_co: TypeAlias = float | floating[_64Bit] | float32 | float16 | integer[
_Complex64_co: TypeAlias = number[_32Bit] | number[_16Bit] | number[_8Bit] | builtins.bool | np.bool
_Complex128_co: TypeAlias = complex | number[_64Bit] | _Complex64_co

_ArrayIndexLike: TypeAlias = SupportsIndex | slice | EllipsisType | _ArrayLikeInt_co | None
_ToIndex: TypeAlias = SupportsIndex | slice | EllipsisType | _ArrayLikeInt_co | None
_ToIndices: TypeAlias = _ToIndex | tuple[_ToIndex, ...]

_UnsignedIntegerCType: TypeAlias = type[
ct.c_uint8 | ct.c_uint16 | ct.c_uint32 | ct.c_uint64
Expand Down Expand Up @@ -982,6 +983,8 @@ if sys.version_info >= (3, 11):
_ConvertibleToComplex: TypeAlias = SupportsComplex | SupportsFloat | SupportsIndex | _CharLike_co
else:
_ConvertibleToComplex: TypeAlias = complex | SupportsComplex | SupportsFloat | SupportsIndex | _CharLike_co
_ConvertibleToTD64: TypeAlias = dt.timedelta | int | _CharLike_co | character | number | timedelta64 | np.bool | None
_ConvertibleToDT64: TypeAlias = dt.date | int | _CharLike_co | character | number | datetime64 | np.bool | None

_NDIterFlagsKind: TypeAlias = L[
"buffered",
Expand Down Expand Up @@ -1070,7 +1073,7 @@ class _HasShapeAndSupportsItem(_HasShape[_ShapeT_co], _SupportsItem[_T_co], Prot

# matches any `x` on `x.type.item() -> _T_co`, e.g. `dtype[np.int8]` gives `_T_co: int`
@type_check_only
class _HashTypeWithItem(Protocol[_T_co]):
class _HasTypeWithItem(Protocol[_T_co]):
@property
def type(self, /) -> type[_SupportsItem[_T_co]]: ...

Expand All @@ -1082,7 +1085,7 @@ class _HasShapeAndDTypeWithItem(Protocol[_ShapeT_co, _T_co]):
@property
def shape(self, /) -> _ShapeT_co: ...
@property
def dtype(self, /) -> _HashTypeWithItem[_T_co]: ...
def dtype(self, /) -> _HasTypeWithItem[_T_co]: ...

@type_check_only
class _HasRealAndImag(Protocol[_RealT_co, _ImagT_co]):
Expand Down Expand Up @@ -1112,6 +1115,7 @@ class _HasDateAttributes(Protocol):
@property
def year(self) -> int: ...


### Mixins (for internal use only)

@type_check_only
Expand Down Expand Up @@ -2006,7 +2010,6 @@ class _ArrayOrScalarCommon:
correction: float = ...,
) -> _ArrayT: ...


class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
__hash__: ClassVar[None] # type: ignore[assignment] # pyright: ignore[reportIncompatibleMethodOverride]
@property
Expand Down Expand Up @@ -2082,16 +2085,56 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
@overload
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> Any: ...
@overload
def __getitem__(self, key: _ArrayIndexLike | tuple[_ArrayIndexLike, ...], /) -> ndarray[_Shape, _DType_co]: ...
def __getitem__(self, key: _ToIndices, /) -> ndarray[_Shape, _DType_co]: ...
@overload
def __getitem__(self: NDArray[void], key: str, /) -> ndarray[_ShapeT_co, np.dtype[Any]]: ...
@overload
def __getitem__(self: NDArray[void], key: list[str], /) -> ndarray[_ShapeT_co, _dtype[void]]: ...

@overload
def __setitem__(self: NDArray[void], key: str | list[str], value: ArrayLike, /) -> None: ...
@overload
def __setitem__(self, key: _ArrayIndexLike | tuple[_ArrayIndexLike, ...], value: ArrayLike, /) -> None: ...
@overload # flexible | object_ | bool
def __setitem__(
self: ndarray[Any, dtype[flexible | object_ | np.bool] | dtypes.StringDType],
key: _ToIndices,
value: object,
/,
) -> None: ...
@overload # integer
def __setitem__(
self: NDArray[integer],
key: _ToIndices,
value: _ConvertibleToInt | _NestedSequence[_ConvertibleToInt] | _ArrayLikeInt_co,
/,
) -> None: ...
@overload # floating
def __setitem__(
self: NDArray[floating],
key: _ToIndices,
value: _ConvertibleToFloat | _NestedSequence[_ConvertibleToFloat | None] | _ArrayLikeFloat_co | None,
/,
) -> None: ...
@overload # complexfloating
def __setitem__(
self: NDArray[complexfloating],
key: _ToIndices,
value: _ConvertibleToComplex | _NestedSequence[_ConvertibleToComplex | None] | _ArrayLikeNumber_co | None,
/,
) -> None: ...
@overload # timedelta64
def __setitem__(
self: NDArray[timedelta64],
key: _ToIndices,
value: _ConvertibleToTD64 | _NestedSequence[_ConvertibleToTD64],
/,
) -> None: ...
@overload # datetime64
def __setitem__(
self: NDArray[datetime64],
key: _ToIndices,
value: _ConvertibleToDT64 | _NestedSequence[_ConvertibleToDT64],
/,
) -> None: ...
@overload # catch-all
def __setitem__(self, key: _ToIndices, value: ArrayLike, /) -> None: ...

@property
def ctypes(self) -> _ctypes[int]: ...
Expand Down Expand Up @@ -4122,16 +4165,16 @@ class timedelta64(_IntegralMixin, generic[_TD64ItemT_co], Generic[_TD64ItemT_co]
@overload
def __init__(self: timedelta64[int], value: dt.timedelta, format: _TimeUnitSpec[_IntTimeUnit], /) -> None: ...
@overload
def __init__(self: timedelta64[int], value: int, format: _TimeUnitSpec[_IntTD64Unit] = ..., /) -> None: ...
def __init__(self: timedelta64[int], value: _IntLike_co, format: _TimeUnitSpec[_IntTD64Unit] = ..., /) -> None: ...
@overload
def __init__(
self: timedelta64[dt.timedelta],
value: dt.timedelta | int,
value: dt.timedelta | _IntLike_co,
format: _TimeUnitSpec[_NativeTD64Unit] = ...,
/,
) -> None: ...
@overload
def __init__(self, value: int | bytes | str | dt.timedelta | None, format: _TimeUnitSpec = ..., /) -> None: ...
def __init__(self, value: _ConvertibleToTD64, format: _TimeUnitSpec = ..., /) -> None: ...

# NOTE: Only a limited number of units support conversion
# to builtin scalar types: `Y`, `M`, `ns`, `ps`, `fs`, `as`
Expand Down
10 changes: 7 additions & 3 deletions numpy/typing/tests/data/pass/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,21 @@ def iterable_func(x: Iterable[object]) -> Iterable[object]:

array_2d = np.ones((3, 3))
array_2d[:2, :2]
array_2d[..., 0]
array_2d[:2, :2] = 0
array_2d[..., 0]
array_2d[..., 0] = 2
array_2d[-1, -1] = None

array_obj = np.zeros(1, dtype=np.object_)
array_obj[0] = slice(None)

# Other special methods
len(array)
str(array)
array_scalar = np.array(1)
int(array_scalar)
float(array_scalar)
# currently does not work due to https://github.com/python/typeshed/issues/1904
# complex(array_scalar)
complex(array_scalar)
bytes(array_scalar)
operator.index(array_scalar)
bool(array_scalar)
Expand Down

0 comments on commit 326f676

Please sign in to comment.