10000 Merge pull request #27019 from guan404ming/array-builtin-type · numpy/numpy@8a7e0e0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8a7e0e0

Browse files
authored
Merge pull request #27019 from guan404ming/array-builtin-type
TYP: improved ``numpy.array`` type hints for array-like input
2 parents 36b7ff9 + f5e479a commit 8a7e0e0

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

numpy/_core/multiarray.pyi

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,12 @@ from numpy._typing import (
8282
_T_co = TypeVar("_T_co", covariant=True)
8383
_T_contra = TypeVar("_T_contra", contravariant=True)
8484
_SCT = TypeVar("_SCT", bound= 10000 generic)
85-
_ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
85+
_ArrayType = TypeVar("_ArrayType", bound=ndarray[Any, Any])
86+
_ArrayType_co = TypeVar(
87+
"_ArrayType_co",
88+
bound=ndarray[Any, Any],
89+
covariant=True,
90+
)
8691

8792
# Valid time units
8893
_UnitKind = L[
@@ -113,6 +118,9 @@ class _SupportsLenAndGetItem(Protocol[_T_contra, _T_co]):
113118
def __len__(self) -> int: ...
114119
def __getitem__(self, key: _T_contra, /) -> _T_co: ...
115120

121+
class _SupportsArray(Protocol[_ArrayType_co]):
122+
def __array__(self, /) -> _ArrayType_co: ...
123+
116124
__all__: list[str]
117125

118126
ALLOW_THREADS: Final[int] # 0 or 1 (system-specific)
@@ -188,6 +196,17 @@ def array(
188196
like: None | _SupportsArrayFunc = ...,
189197
) -> _ArrayType: ...
190198
@overload
199+
def array(
200+
object: _SupportsArray[_ArrayType],
201+
dtype: None = ...,
202+
*,
203+
copy: None | bool | _CopyMode = ...,
204+
order: _OrderKACF = ...,
205+
subok: L[True],
206+
ndmin: L[0] = ...,
207+
like: None | _SupportsArrayFunc = ...,
208+
) -> _ArrayType: ...
209+
@overload
191210
def array(
192211
object: _ArrayLike[_SCT],
193212
dtype: None = ...,

numpy/typing/tests/data/reveal/array_constructors.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ i8: np.int64
2020
A: npt.NDArray[np.float64]
2121
B: SubClass[np.float64]
2222
C: list[int]
23+
D: SubClass[np.float64 | np.int64]
2324

2425
def func(i: int, j: int, **kwargs: Any) -> SubClass[np.float64]: ...
2526

@@ -31,12 +32,16 @@ assert_type(np.empty_like(A, dtype='c16'), npt.NDArray[Any])
3132

3233
assert_type(np.array(A), npt.NDArray[np.float64])
3334
assert_type(np.array(B), npt.NDArray[np.float64])
34-
assert_type(np.array(B, subok=True), SubClass[np.float64])
3535
assert_type(np.array([1, 1.0]), npt.NDArray[Any])
3636
assert_type(np.array(deque([1, 2, 3])), npt.NDArray[Any])
3737
assert_type(np.array(A, dtype=np.int64), npt.NDArray[np.int64])
3838
assert_type(np.array(A, dtype='c16'), npt.NDArray[Any])
3939
assert_type(np.array(A, like=A), npt.NDArray[np.float64])
40+
assert_type(np.array(A, subok=True), npt.NDArray[np.float64])
41+
assert_type(np.array(B, subok=True), SubClass[np.float64])
42+
assert_type(np.array(B, subok=True, ndmin=0), SubClass[np.float64])
43+
assert_type(np.array(B, subok=True, ndmin=1), SubClass[np.float64])
44+
assert_type(np.array(D), npt.NDArray[np.float64 | np.int64])
4045

4146
assert_type(np.zeros([1, 5, 6]), npt.NDArray[np.float64])
4247
assert_type(np.zeros([1, 5, 6], dtype=np.int64), npt.NDArray[np.int64])

0 commit comments

Comments
 (0)
0