8000 TYP: use a gradual shape-type if unknown · numpy/numpy@12b741e · GitHub
[go: up one dir, main page]

Skip to content

Commit 12b741e

Browse files
committed
TYP: use a gradual shape-type if unknown
1 parent 3ca7ffb commit 12b741e

15 files changed

+136
-165
lines changed

numpy/__init__.pyi

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ from numpy._typing import (
4444
_DTypeLikeVoid,
4545
_VoidDTypeLike,
4646
# Shapes
47+
_AnyShape,
4748
_Shape,
4849
_ShapeLike,
4950
# Scalars
@@ -794,7 +795,7 @@ _RealArrayT = TypeVar("_RealArrayT", bound=NDArray[floating | integer | timedelt
794795
_NumericArrayT = TypeVar("_NumericArrayT", bound=NDArray[number | timedelta64 | object_])
795796

796797
_ShapeT = TypeVar("_ShapeT", bound=_Shape)
797-
_ShapeT_co = TypeVar("_ShapeT_co", bound=_Shape, default=_Shape, covariant=True)
798+
_ShapeT_co = TypeVar("_ShapeT_co", bound=_Shape, default=_AnyShape, covariant=True)
798799
_1DShapeT = TypeVar("_1DShapeT", bound=_1D)
799800
_2DShapeT_co = TypeVar("_2DShapeT_co", bound=_2D, default=_2D, covariant=True)
800801
_1NShapeT = TypeVar("_1NShapeT", bound=tuple[L[1], *tuple[L[1], ...]]) # (1,) | (1, 1) | (1, 1, 1) | ...
@@ -1579,11 +1580,11 @@ class dtype(Generic[_ScalarT_co], metaclass=_DTypeMeta):
15791580
@property
15801581
def num(self) -> _DTypeNum: ...
15811582
@property
1582-
def shape(self) -> tuple[()] | _Shape: ...
1583+
def shape(self) -> _AnyShape: ...
15831584
@property
15841585
def ndim(self) -> int: ...
15851586
@property
1586-
def subdtype(self) -> tuple[dtype, _Shape] | None: ...
1587+
def subdtype(self) -> tuple[dtype, _AnyShape] | None: ...
15871588
def newbyteorder(self, new_order: _ByteOrder = ..., /) -> Self: ...
15881589
@property
15891590
def str(self) -> LiteralString: ...
@@ -1627,9 +1628,9 @@ class flatiter(Generic[_ArrayT_co]):
16271628
@overload
16281629
def __array__(self: flatiter[ndarray[_1DShapeT, Any]], dtype: _DTypeT, /) -> ndarray[_1DShapeT, _DTypeT]: ...
16291630
@overload
1630-
def __array__(self: flatiter[ndarray[_Shape, _DTypeT]], dtype: None = ..., /) -> ndarray[_Shape, _DTypeT]: ...
1631+
def __array__(self: flatiter[ndarray[Any, _DTypeT]], dtype: None = ..., /) -> ndarray[_AnyShape, _DTypeT]: ...
16311632
@overload
1632-
def __array__(self, dtype: _DTypeT, /) -> ndarray[_Shape, _DTypeT]: ...
1633+
def __array__(self, dtype: _DTypeT, /) -> ndarray[_AnyShape, _DTypeT]: ...
16331634

16341635
@type_check_only
16351636
class _ArrayOrScalarCommon:
@@ -2084,11 +2085,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
20842085
) -> ndarray[_ShapeT, _DTypeT]: ...
20852086

20862087
@overload
2087-
def __getitem__(self, key: _ArrayInt_co | tuple[_ArrayInt_co, ...], /) -> ndarray[_Shape, _DTypeT_co]: ...
2088+
def __getitem__(self, key: _ArrayInt_co | tuple[_ArrayInt_co, ...], /) -> ndarray[_AnyShape, _DTypeT_co]: ...
20882089
@overload
20892090
def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...], /) -> Any: ...
20902091
@overload
2091-
def __getitem__(self, key: _ToIndices, /) -> ndarray[_Shape, _DTypeT_co]: ...
2092+
def __getitem__(self, key: _ToIndices, /) -> ndarray[_AnyShape, _DTypeT_co]: ...
20922093
@overload
20932094
def __getitem__(self: NDArray[void], key: str, /) -> ndarray[_ShapeT_co, np.dtype]: ...
20942095
@overload
@@ -2166,6 +2167,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21662167
*args: SupportsIndex,
21672168
) -> str: ...
21682169

2170+
@overload # this first overload prevents mypy from over-eagerly selecting `tuple[()]` in case of `_AnyShape`
2171+
def tolist(self: ndarray[tuple[Never], dtype[generic[_T]]], /) -> Any: ...
21692172
@overload
21702173
def tolist(self: ndarray[tuple[()], dtype[generic[_T]]], /) -> _T: ...
21712174
@overload
@@ -2187,13 +2190,13 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21872190
def squeeze(
21882191
self,
21892192
axis: SupportsIndex | tuple[SupportsIndex, ...] | None = ...,
2190-
) -> ndarray[_Shape, _DTypeT_co]: ...
2193+
) -> ndarray[_AnyShape, _DTypeT_co]: ...
21912194

21922195
def swapaxes(
21932196
self,
21942197
axis1: SupportsIndex,
21952198
axis2: SupportsIndex,
2196-
) -> ndarray[_Shape, _DTypeT_co]: ...
2199+
) -> ndarray[_AnyShape, _DTypeT_co]: ...
21972200

21982201
@overload
21992202
def transpose(self, axes: _ShapeLike | None, /) -> Self: ...
@@ -2320,7 +2323,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
23202323
offset: SupportsIndex = ...,
23212324
axis1: SupportsIndex = ...,
23222325
axis2: SupportsIndex = ...,
2323-
) -> ndarray[_Shape, _DTypeT_co]: ...
2326+
) -> ndarray[_AnyShape, _DTypeT_co]: ...
23242327

23252328
# 1D + 1D returns a scalar;
23262329
# all other with at least 1 non-0D array return an ndarray.
@@ -2396,7 +2399,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
23962399
axis: SupportsIndex | None = ...,
23972400
out: None = ...,
23982401
mode: _ModeKind = ...,
2399-
) -> ndarray[_Shape, _DTypeT_co]: ...
2402+
) -> ndarray[_AnyShape, _DTypeT_co]: ...
24002403
@overload
24012404
def take(
24022405
self,
@@ -2417,7 +2420,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
24172420
self,
24182421
repeats: _ArrayLikeInt_co,
24192422
axis: SupportsIndex,
2420-
) -> ndarray[_Shape, _DTypeT_co]: ...
2423+
) -> ndarray[_AnyShape, _DTypeT_co]: ...
24212424

24222425
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DTypeT_co]: ...
24232426
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[tuple[int], _DTypeT_co]: ...
@@ -2493,7 +2496,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
24932496
*shape: SupportsIndex,
24942497
order: _OrderACF = "C",
24952498
copy: builtins.bool | None = None,
2496-
) -> ndarray[_Shape, _DTypeT_co]: ...
2499+
) -> ndarray[_AnyShape, _DTypeT_co]: ...
24972500
@overload # (sequence[index])
24982501
def reshape(
24992502
self,
@@ -2502,7 +2505,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25022505
*,
25032506
order: _OrderACF = "C",
25042507
copy: builtins.bool | None = None,
2505-
) -> ndarray[_Shape, _DTypeT_co]: ...
2508+
) -> ndarray[_AnyShape, _DTypeT_co]: ...
25062509

25072510
@overload
25082511
def astype(
@@ -4941,7 +4944,7 @@ class broadcast:
49414944
@property
49424945
def numiter(self) -> int: ...
49434946
@property
4944-
def shape(self) -> _Shape: ...
4947+
def shape(self) -> _AnyShape: ...
49454948
@property
49464949
def size(self) -> int: ...
49474950
def __next__(self) -> tuple[Any, ...]: ...
@@ -5398,8 +5401,8 @@ class matrix(ndarray[_2DShapeT_co, _DTypeT_co]):
53985401
def A(self) -> ndarray[_2DShapeT_co, _DTypeT_co]: ...
53995402
def getA(self) -> ndarray[_2DShapeT_co, _DTypeT_co]: ...
54005403
@property
5401-
def A1(self) -> ndarray[_Shape, _DTypeT_co]: ...
5402-
def getA1(self) -> ndarray[_Shape, _DTypeT_co]: ...
5404+
def A1(self) -> ndarray[_AnyShape, _DTypeT_co]: ...
5405+
def getA1(self) -> ndarray[_AnyShape, _DTypeT_co]: ...
54035406
@property
54045407
def H(self) -> matrix[_2D, _DTypeT_co]: ...
54055408
def getH(self) -> matrix[_2D, _DTypeT_co]: ...

numpy/_core/defchararray.pyi

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,11 @@
1-
from typing import (
2-
Any,
3-
SupportsIndex,
4-
SupportsInt,
5-
TypeAlias,
6-
overload,
7-
)
8-
from typing import (
9-
Literal as L,
10-
)
1+
from typing import Any, Self, SupportsIndex, SupportsInt, TypeAlias, overload
2+
from typing import Literal as L
113

124
from typing_extensions import TypeVar
135

146
import numpy as np
157
from numpy import (
168
_OrderKACF,
17-
_SupportsArray,
189
_SupportsBuffer,
1910
bytes_,
2011
dtype,
@@ -24,29 +15,13 @@ from numpy import (
2415
str_,
2516
)
2617
from numpy._core.multiarray import compare_chararrays
27-
from numpy._typing import (
28-
NDArray,
29-
_Shape,
30-
_ShapeLike,
31-
)
32-
from numpy._typing import (
33-
_ArrayLikeAnyString_co as UST_co,
34-
)
35-
from numpy._typing import (
36-
_ArrayLikeBool_co as b_co,
37-
)
38-
from numpy._typing import (
39-
_ArrayLikeBytes_co as S_co,
40-
)
41-
from numpy._typing import (
42-
_ArrayLikeInt_co as i_co,
43-
)
44-
from numpy._typing import (
45-
_ArrayLikeStr_co as U_co,
46-
)
47-
from numpy._typing import (
48-
_ArrayLikeString_co as T_co,
49-
)
18+
from numpy._typing import NDArray, _AnyShape, _Shape, _ShapeLike, _SupportsArray
19+
from numpy._typing import _ArrayLikeAnyString_co as UST_co
20+
from numpy._typing import _ArrayLikeBool_co as b_co
21+
from numpy._typing import _ArrayLikeBytes_co as S_co
22+
from numpy._typing import _ArrayLikeInt_co as i_co
23+
from numpy._typing import _ArrayLikeStr_co as U_co
24+
from numpy._typing import _ArrayLikeString_co as T_co
5025

5126
__all__ = [
5227
"equal",
@@ -104,14 +79,15 @@ __all__ = [
10479
"chararray",
10580
]
10681

107-
_ShapeT_co = TypeVar("_ShapeT_co", bound=_Shape, default=_Shape, covariant=True)
82+
_ShapeT_co = TypeVar("_ShapeT_co", bound=_Shape, default=_AnyShape, covariant=True)
10883
_CharacterT = TypeVar("_CharacterT", bound=np.character)
10984
_CharDTypeT_co = TypeVar("_CharDTypeT_co", bound=dtype[np.character], default=dtype, covariant=True)
110-
_CharArray: TypeAlias = chararray[_Shape, dtype[_CharacterT]]
11185

112-
_StringDTypeArray: TypeAlias = np.ndarray[_Shape, np.dtypes.StringDType]
86+
_CharArray: TypeAlias = chararray[_AnyShape, dtype[_CharacterT]]
87+
88+
_StringDTypeArray: TypeAlias = np.ndarray[_AnyShape, np.dtypes.StringDType]
89+
_StringDTypeOrUnicodeArray: TypeAlias = _StringDTypeArray | NDArray[np.str_]
11390
_StringDTypeSupportsArray: TypeAlias = _SupportsArray[np.dtypes.StringDType]
114-
_StringDTypeOrUnicodeArray: TypeAlias = np.ndarray[_Shape, np.dtype[np.str_]] | np.ndarray[_Shape, np.dtypes.StringDType]
11591

11692
class chararray(ndarray[_ShapeT_co, _CharDTypeT_co]):
11793
@overload
@@ -124,7 +100,7 @@ class chararray(ndarray[_ShapeT_co, _CharDTypeT_co]):
124100
offset: SupportsIndex = ...,
125101
strides: _ShapeLike = ...,
126102
order: _OrderKACF = ...,
127-
) -> chararray[_Shape, dtype[bytes_]]: ...
103+
) -> _CharArray[bytes_]: ...
128104
@overload
129105
def __new__(
130106
subtype,
@@ -135,12 +111,12 @@ class chararray(ndarray[_ShapeT_co, _CharDTypeT_co]):
135111
offset: SupportsIndex = ...,
136112
strides: _ShapeLike = ...,
137113
order: _OrderKACF = ...,
138-
) -> chararray[_Shape, dtype[str_]]: ...
114+
) -> _CharArray[str_]: ...
139115

140116
def __array_finalize__(self, obj: object) -> None: ...
141-
def __mul__(self, other: i_co) -> chararray[_Shape, _CharDTypeT_co]: ...
142-
def __rmul__(self, other: i_co) -> chararray[_Shape, _CharDTypeT_co]: ...
143-
def __mod__(self, i: Any) -> chararray[_Shape, _CharDTypeT_co]: ...
117+
def __mul__(self, other: i_co) -> chararray[_AnyShape, _CharDTypeT_co]: ...
118+
def __rmul__(self, other: i_co) -> chararray[_AnyShape, _CharDTypeT_co]: ...
119+
def __mod__(self, i: Any) -> chararray[_AnyShape, _CharDTypeT_co]: ...
144120

145121
@overload
146122
def __eq__(
@@ -288,7 +264,7 @@ class chararray(ndarray[_ShapeT_co, _CharDTypeT_co]):
288264
def expandtabs(
289265
self,
290266
tabsize: i_co = ...,
291-
) -> chararray[_Shape, _CharDTypeT_co]: ...
267+
) -> Self: ...
292268

293269
@overload
294270
def find(
@@ -513,12 +489,12 @@ class chararray(ndarray[_ShapeT_co, _CharDTypeT_co]):
513489
deletechars: S_co | None = ...,
514490
) -> _CharArray[bytes_]: ...
515491

516-
def zfill(self, width: i_co) -> chararray[_Shape, _CharDTypeT_co]: ...
517-
def capitalize(self) -> chararray[_ShapeT_co, _CharDTypeT_co]: ...
518-
def title(self) -> chararray[_ShapeT_co, _CharDTypeT_co]: ...
519-
def swapcase(self) -> chararray[_ShapeT_co, _CharDTypeT_co]: ...
520-
def lower(self) -> chararray[_ShapeT_co, _CharDTypeT_co]: ...
521-
def upper(self) -> chararray[_ShapeT_co, _CharDTypeT_co]: ...
492+
def zfill(self, width: i_co) -> Self: ...
493+
def capitalize(self) -> Self: ...
494+
def title(self) -> Self: ...
495+
def swapcase(self) -> Self: ...
496+
def lower(self) -> Self: ...
497+
def upper(self) -> Self: ...
522498
def isalnum(self) -> ndarray[_ShapeT_co, dtype[np.bool]]: ...
523499
def isalpha(self) -> ndarray[_ShapeT_co, dtype[np.bool]]: ...
524500
def isdigit(self) -> ndarray[_ShapeT_co, dtype[np.bool]]: ...

numpy/_core/fromnumeric.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ from numpy._typing import (
4141
ArrayLike,
4242
DTypeLike,
4343
NDArray,
44+
_AnyShape,
4445
_ArrayLike,
4546
_ArrayLikeBool_co,
4647
_ArrayLikeComplex_co,
@@ -579,7 +580,7 @@ def nonzero(a: _ArrayLike[Any]) -> tuple[NDArray[intp], ...]: ...
579580

580581
# this prevents `Any` from being returned with Pyright
581582
@overload
582-
def shape(a: _SupportsShape[Never]) -> tuple[int, ...]: ...
583+
def shape(a: _SupportsShape[Never]) -> _AnyShape: ...
583584
@overload
584585
def shape(a: _SupportsShape[_ShapeT]) -> _ShapeT: ...
585586
@overload
@@ -594,7 +595,7 @@ def shape(a: _PyArray[_PyArray[_PyScalar]]) -> tuple[int, int]: ...
594595
@overload
595596
def shape(a: memoryview | bytearray) -> tuple[int]: ...
596597
@overload
597-
def shape(a: ArrayLike) -> tuple[int, ...]: ...
598+
def shape(a: ArrayLike) -> _AnyShape: ...
598599

599600
@overload
600601
def compress(

numpy/_core/multiarray.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ from numpy._typing import (
8888
_NestedSequence,
8989
_ScalarLike_co,
9090
# Shapes
91+
_Shape,
9192
_ShapeLike,
9293
_SupportsArrayFunc,
9394
_SupportsDType,
@@ -207,7 +208,7 @@ _IDType = TypeVar("_IDType")
207208
_Nin = TypeVar("_Nin", bound=int)
208209
_Nout = TypeVar("_Nout", bound=int)
209210

210-
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
211+
_ShapeT = TypeVar("_ShapeT", bound=_Shape)
211212
_Array: TypeAlias = ndarray[_ShapeT, dtype[_ScalarT]]
212213
_Array1D: TypeAlias = ndarray[tuple[int], dtype[_ScalarT]]
213214

numpy/_core/numeric.pyi

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import numpy as np
1818
from numpy import (
1919
False_,
2020
True_,
21-
_AnyShapeT,
2221
_OrderCF,
2322
_OrderKACF,
2423
# re-exports
@@ -63,26 +62,17 @@ from numpy._typing import (
6362
_DTypeLike,
6463
_NestedSequence,
6564
_ScalarLike_co,
65+
_Shape,
6666
_ShapeLike,
6767
_SupportsArrayFunc,
6868
_SupportsDType,
6969
)
7070

71-
from .fromnumeric import (
72-
all as all,
73-
)
74-
from .fromnumeric import (
75-
any as any,
76-
)
77-
from .fromnumeric import (
78-
argpartition as argpartition,
79-
)
80-
from .fromnumeric import (
81-
matrix_transpose as matrix_transpose,
82-
)
83-
from .fromnumeric import (
84-
mean as mean,
85-
)
71+
from .fromnumeric import all as all
72+
from .fromnumeric import any as any
73+
from .fromnumeric import argpartition as argpartition
74+
from .fromnumeric import matrix_transpose as matrix_transpose
75+
from .fromnumeric import mean as mean
8676
from .multiarray import (
8777
# other
8878
_Array,
@@ -198,7 +188,16 @@ _T = TypeVar("_T")
198188
_ScalarT = TypeVar("_ScalarT", bound=generic)
199189
_DTypeT = TypeVar("_DTypeT", bound=np.dtype)
200190
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
201-
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
191+
_ShapeT = TypeVar("_ShapeT", bound=_Shape)
192+
_AnyShapeT = TypeVar(
193+
"_AnyShapeT",
194+
tuple[()],
195+
tuple[int],
196+
tuple[int, int],
197+
tuple[int, int, int],
198+
tuple[int, int, int, int],
199+
tuple[int, ...],
200+
)
202201

203202
_CorrelateMode: TypeAlias = L["valid", "same", "full"]
204203

0 commit comments

Comments
 (0)
0