8000 Merge pull request #28644 from jorenham/typing/fix-27944 · numpy/numpy@a6ebba8 · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit a6ebba8

Browse files
authored
Merge pull request #28644 from jorenham/typing/fix-27944
TYP: fix `ndarray.tolist()` and `.item()` for unknown dtype
2 parents 5825d67 + 3e30498 commit a6ebba8

File tree

2 files changed

+23
-43
lines changed

2 files changed

+23
-43
lines changed

numpy/__init__.pyi

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,35 +1092,11 @@ class _SupportsItem(Protocol[_T_co]):
10921092
class _SupportsDLPack(Protocol[_T_contra]):
10931093
def __dlpack__(self, /, *, stream: _T_contra | None = None) -> CapsuleType: ...
10941094

1095-
@type_check_only
1096-
class _HasShape(Protocol[_ShapeT_co]):
1097-
@property
1098-
def shape(self, /) -> _ShapeT_co: ...
1099-
11001095
@type_check_only
11011096
class _HasDType(Protocol[_T_co]):
11021097
@property
11031098
def dtype(self, /) -> _T_co: ...
11041099

1105-
@type_check_only
1106-
class _HasShapeAndSupportsItem(_HasShape[_ShapeT_co], _SupportsItem[_T_co], Protocol[_ShapeT_co, _T_co]): ...
1107-
1108-
# matches any `x` on `x.type.item() -> _T_co`, e.g. `dtype[np.int8]` gives `_T_co: int`
1109-
@type_check_only
1110-
class _HasTypeWithItem(Protocol[_T_co]):
1111-
@property
1112-
def type(self, /) -> type[_SupportsItem[_T_co]]: ...
1113-
1114-
# matches any `x` on `x.shape: _ShapeT_co` and `x.dtype.type.item() -> _T_co`,
1115-
# useful for capturing the item-type (`_T_co`) of the scalar-type of an array with
1116-
# specific shape (`_ShapeT_co`).
1117-
@type_check_only
1118-
class _HasShapeAndDTypeWithItem(Protocol[_ShapeT_co, _T_co]):
1119-
@property
1120-
def shape(self, /) -> _ShapeT_co: ...
1121-
@property
1122-
def dtype(self, /) -> _HasTypeWithItem[_T_co]: ...
1123-
11241100
@type_check_only
11251101
class _HasRealAndImag(Protocol[_RealT_co, _ImagT_co]):
11261102
@property
@@ -2199,29 +2175,26 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DType_co]):
21992175
@property
22002176
def flat(self) -> flatiter[Self]: ...
22012177

2202-
@overload # special casing for `StringDType`, which has no scalar type
2203-
def item(self: ndarray[Any, dtypes.StringDType], /) -> str: ...
2204-
@overload
2205-
def item(self: ndarray[Any, dtypes.StringDType], arg0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /) -> str: ...
2206-
@overload
2207-
def item(self: ndarray[Any, dtypes.StringDType], /, *args: SupportsIndex) -> str: ...
22082178
@overload # use the same output type as that of the underlying `generic`
2209-
def item(self: _HasShapeAndDTypeWithItem[Any, _T], /) -> _T: ...
2210-
@overload
2211-
def item(self: _HasShapeAndDTypeWithItem[Any, _T], arg0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /) -> _T: ...
2212-
@overload
2213-
def item(self: _HasShapeAndDTypeWithItem[Any, _T], /, *args: SupportsIndex) -> _T: ...
2179+
def item(self: NDArray[generic[_T]], i0: SupportsIndex | tuple[SupportsIndex, ...] = ..., /, *args: SupportsIndex) -> _T: ...
2180+
@overload # special casing for `StringDType`, which has no scalar type
2181+
def item(
2182+
self: ndarray[Any, dtypes.StringDType],
2183+
arg0: SupportsIndex | tuple[SupportsIndex, ...] = ...,
2184+
/,
2185+
*args: SupportsIndex,
2186+
) -> str: ...
22142187

22152188
@overload
2216-
def tolist(self: _HasShapeAndSupportsItem[tuple[()], _T], /) -> _T: ...
2189+
def tolist(self: ndarray[tuple[()], dtype[generic[_T]]], /) -> _T: ...
22172190
@overload
2218-
def tolist(self: _HasShapeAndSupportsItem[tuple[int], _T], /) -> list[_T]: ...
2191+
def tolist(self: ndarray[tuple[int], dtype[generic[_T]]], /) -> list[_T]: ...
22192192
@overload
2220-
def tolist(self: _HasShapeAndSupportsItem[tuple[int, int], _T], /) -> list[list[_T]]: ...
2193+
def tolist(self: ndarray[tuple[int, int], dtype[generic[_T]]], /) -> list[list[_T]]: ...
22212194
@overload
2222-
def tolist(self: _HasShapeAndSupportsItem[tuple[int, int, int], _T], /) -> list[list[list[_T]]]: ...
2195+
def tolist(self: ndarray[tuple[int, int, int], dtype[generic[_T]]], /) -> list[list[list[_T]]]: ...
22232196
@overload
2224-
def tolist(self: _HasShapeAndSupportsItem[Any, _T], /) -> _T | list[_T] | list[list[_T]] | list[list[list[Any]]]: ...
2197+
def tolist(self, /) -> Any: ...
22252198

22262199
@overload
22272200
def resize(self, new_shape: _ShapeLike, /, *, refcheck: builtins.bool = ...) -> None: ...
@@ -5364,7 +5337,7 @@ class matrix(ndarray[_2DShapeT_co, _DType_co]):
53645337
def ptp(self, axis: None | _ShapeLike = ..., out: _ArrayT = ...) -> _ArrayT: ...
53655338

53665339
def squeeze(self, axis: None | _ShapeLike = ...) -> matrix[_2D, _DType_co]: ...
5367-
def tolist(self: _SupportsItem[_T]) -> list[list[_T]]: ...
5340+
def tolist(self: matrix[Any, dtype[generic[_T]]]) -> list[list[_T]]: ... # pyright: ignore[reportIncompatibleMethodOverride]
53685341
def ravel(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]
53695342
def flatten(self, /, order: _OrderKACF = "C") -> matrix[tuple[L[1], int], _DType_co]: ... # pyright: ignore[reportIncompatibleMethodOverride]
53705343

numpy/typing/tests/data/reveal/ndarray_conversion.pyi

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,15 @@ assert_type(b1_0d.tolist(), bool)
3030
assert_type(u2_1d.tolist(), list[int])
3131
assert_type(i4_2d.tolist(), list[list[int]])
3232
assert_type(f8_3d.tolist(), list[list[list[float]]])
33-
assert_type(cG_4d.tolist(), complex | list[complex] | list[list[complex]] | list[list[list[Any]]])
34-
assert_type(i0_nd.tolist(), int | list[int] | list[list[int]] | list[list[list[Any]]])
33+
assert_type(cG_4d.tolist(), Any)
34+
assert_type(i0_nd.tolist(), Any)
35+
36+
# regression tests for numpy/numpy#27944
37+
any_dtype: np.ndarray[Any, Any]
38+
any_sctype: np.ndarray[Any, Any]
39+
assert_type(any_dtype.tolist(), Any)
40+
assert_type(any_sctype.tolist(), Any)
41+
3542

3643
# itemset does not return a value
3744
# tobytes is pretty simple

0 commit comments

Comments
 (0)
0