8000 TYP: Shape-typed array constructors: ``numpy.{empty, zeros, ones, ful… · numpy/numpy@ef8555d · GitHub
[go: up one dir, main page]

Skip to content

Commit ef8555d

Browse files
committed
TYP: Shape-typed array constructors: numpy.{empty, zeros, ones, full}
1 parent fc8a569 commit ef8555d

File tree

3 files changed

+249
-109
lines changed

3 files changed

+249
-109
lines changed

numpy/_core/multiarray.pyi

Lines changed: 127 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
# TODO: Sort out any and all missing functions in this namespace
2-
import builtins
32
import os
43
import datetime as dt
54
from collections.abc import Sequence, Callable, Iterable
65
from typing import (
76
Literal as L,
87
Any,
98
overload,
9+
TypeAlias,
1010
TypeVar,
11+
TypedDict,
1112
SupportsIndex,
1213
final,
1314
Final,
1415
Protocol,
1516
ClassVar,
17+
type_check_only,
1618
)
19+
from typing_extensions import Unpack
1720

1821
import numpy as np
19-
from numpy import (
22+
from numpy import ( # type: ignore[attr-defined]
2023
# Re-exports
2124
busdaycalendar as busdaycalendar,
2225
broadcast as broadcast,
@@ -56,6 +59,7 @@ from numpy._typing import (
5659
# DTypes
5760
DTypeLike,
5861
_DTypeLike,
62+
_SupportsDType,
5963

6064
# Arrays
6165
NDArray,
@@ -82,12 +86,17 @@ from numpy._typing import (
8286
_T_co = TypeVar("_T_co", covariant=True)
8387
_T_contra = TypeVar("_T_contra", contravariant=True)
8488
_SCT = TypeVar("_SCT", bound=generic)
89+
_DType = TypeVar("_DType", bound=np.dtype[Any])
8590
_ArrayType = TypeVar("_ArrayType", bound=ndarray[Any, Any])
8691
_ArrayType_co = TypeVar(
8792
"_ArrayType_co",
8893
bound=ndarray[Any, Any],
8994
covariant=True,
9095
)
96+
_SizeType = TypeVar("_SizeType", bound=int)
97+
_ShapeType = TypeVar("_ShapeType", bound=tuple[int, ...])
98+
99+
_Array: TypeAlias = ndarray[_ShapeType, dtype[_SCT]]
91100

92101
# Valid time units
93102
_UnitKind = L[
@@ -121,6 +130,119 @@ class _SupportsLenAndGetItem(Protocol[_T_contra, _T_co]):
121130
class _SupportsArray(Protocol[_ArrayType_co]):
122131
def __array__(self, /) -> _ArrayType_co: ...
123132

133+
@type_check_only
134+
class _KwargsEmptyLike(TypedDict, total=False):
135+
device: None | L["cpu"]
136+
137+
@type_check_only
138+
class _KwargsEmpty(_KwargsEmptyLike, total=False):
139+
like: None | _SupportsArrayFunc
140+
141+
@type_check_only
142+
class _ConstructorEmpty(Protocol):
143+
# 1-D shape
144+
@overload
145+
def __call__(
146+
self, /,
147+
shape: _SizeType,
148+
dtype: None = ...,
149+
order: _OrderCF = ...,
150+
**kwargs: Unpack[_KwargsEmpty],
151+
) -> _Array[tuple[_SizeType], float64]: ...
152+
@overload
153+
def __call__(
154+
self, /,
155+
shape: _SizeType,
156+
dtype: _DType | _SupportsDType[_DType],
157+
order: _OrderCF = ...,
158+
**kwargs: Unpack[_KwargsEmpty],
159+
) -> ndarray[tuple[_SizeType], _DType]: ...
160+
@overload
161+
def __call__(
162+
self, /,
163+
shape: _SizeType,
164+
dtype: type[_SCT],
165+
order: _OrderCF = ...,
166+
**kwargs: Unpack[_KwargsEmpty],
167+
) -> _Array[tuple[_SizeType], _SCT]: ...
168+
@overload
169+
def __call__(
170+
self, /,
171+
shape: _SizeType,
172+
dtype: DTypeLike,
173+
order: _OrderCF = ...,
174+
**kwargs: Unpack[_KwargsEmpty],
175+
) -> _Array[tuple[_SizeType], Any]: ...
176+
177+
# known shape
178+
@overload
179+
def __call__(
180+
self, /,
181+
shape: _ShapeType,
182+
dtype: None = ...,
183+
order: _OrderCF = ...,
184+
**kwargs: Unpack[_KwargsEmpty],
185+
) -> _Array[_ShapeType, float64]: ...
186+
@overload
187+
def __call__(
188+
self, /,
189+
shape: _ShapeType,
190+
dtype: _DType | _SupportsDType[_DType],
191+
order: _OrderCF = ...,
192+
**kwargs: Unpack[_KwargsEmpty],
193+
) -> ndarray[_ShapeType, _DType]: ...
194+
@overload
195+
def __call__(
196+
self, /,
197+
shape: _ShapeType,
198+
dtype: type[_SCT],
199+
order: _OrderCF = ...,
200+
**kwargs: Unpack[_KwargsEmpty],
201+
) -> _Array[_ShapeType, _SCT]: ...
202+
@overload
203+
def __call__(
204+
self, /,
205+
shape: _ShapeType,
206+
dtype: DTypeLike,
207+
order: _OrderCF = ...,
208+
**kwargs: Unpack[_KwargsEmpty],
209+
) -> _Array[_ShapeType, Any]: ...
210+
211+
# unknown shape
212+
@overload
213+
def __call__(
214+
self, /,
215+
shape: _ShapeLike,
216+
dtype: None = ...,
217+
order: _OrderCF = ...,
218+
**kwargs: Unpack[_KwargsEmpty],
219+
) -> NDArray[float64]: ...
220+
@overload
221+
def __call__(
222+
self, /,
223+
shape: _ShapeLike,
224+
dtype: _DType | _SupportsDType[_DType],
225+
order: _OrderCF = ...,
226+
**kwargs: Unpack[_KwargsEmpty],
227+
) -> ndarray[Any, _DType]: ...
228+
@overload
229+
def __call__(
230+
self, /,
231+
shape: _ShapeLike,
232+
dtype: type[_SCT],
233+
order: _OrderCF = ...,
234+
**kwargs: Unpack[_KwargsEmpty],
235+
) -> NDArray[_SCT]: ...
236+
@overload
237+
def __call__(
238+
self, /,
239+
shape: _ShapeLike,
240+
dtype: DTypeLike,
241+
order: _OrderCF = ...,
242+
**kwargs: Unpack[_KwargsEmpty],
243+
) -> NDArray[Any]: ...
244+
245+
124246
__all__: list[str]
125247

126248
ALLOW_THREADS: Final[int] # 0 or 1 (system-specific)
@@ -133,6 +255,9 @@ MAY_SHARE_BOUNDS: L[0]
133255
MAY_SHARE_EXACT: L[-1]
134256
tracemalloc_domain: L[389047]
135257

258+
zeros: Final[_ConstructorEmpty]
259+
empty: Final[_ConstructorEmpty]
260+
136261
@overload
137262
def empty_like(
138263
prototype: _ArrayType,
@@ -251,62 +376,6 @@ def array(
251376
like: None | _SupportsArrayFunc = ...,
252377
) -> NDArray[Any]: ...
253378

254-
@overload
255-
def zeros(
256-
shape: _ShapeLike,
257-
dtype: None = ...,
258-
order: _OrderCF = ...,
259-
*,
260-
device: None | L["cpu"] = ...,
261-
like: None | _SupportsArrayFunc = ...,
262-
) -> NDArray[float64]: ...
263-
@overload
264-
def zeros(
265-
shape: _ShapeLike,
266-
dtype: _DTypeLike[_SCT],
267-
order: _OrderCF = ...,
268-
*,
269-
device: None | L["cpu"] = ...,
270-
like: None | _SupportsArrayFunc = ...,
271-
) -> NDArray[_SCT]: ...
272-
@overload
273-
def zeros(
274-
shape: _ShapeLike,
275-
dtype: DTypeLike,
276-
order: _OrderCF = ...,
277-
*,
278-
device: None | L["cpu"] = ...,
279-
like: None | _SupportsArrayFunc = ...,
280-
) -> NDArray[Any]: ...
281-
282-
@overload
283-
def empty(
284-
shape: _ShapeLike,
285-
dtype: None = ...,
286-
order: _OrderCF = ...,
287-
*,
288-
device: None | L["cpu"] = ...,
289-
like: None | _SupportsArrayFunc = ...,
290-
) -> NDArray[float64]: ...
291-
@overload
292-
def empty(
293-
shape: _ShapeLike,
294-
dtype: _DTypeLike[_SCT],
295-
order: _OrderCF = ...,
296-
*,
297-
device: None | L["cpu"] = ...,
298-
like: None | _SupportsArrayFunc = ...,
299-
) -> NDArray[_SCT]: ...
300-
@overload
301-
def empty(
302-
shape: _ShapeLike,
303-
dtype: DTypeLike,
304-
order: _OrderCF = ...,
305-
*,
306-
device: None | L["cpu"] = ...,
307-
like: None | _SupportsArrayFunc = ...,
308-
) -> NDArray[Any]: ...
309-
310379
@overload
311380
def unravel_index( # type: ignore[misc]
312381
indices: _IntLike_co,

0 commit comments

Comments
 (0)
0