8000 Merge pull request #19261 from BvB93/2dim · numpy/numpy@1e2aa70 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1e2aa70

Browse files
authored
Merge pull request #19261 from BvB93/2dim
ENH: Add annotations for `np.lib.twodim_base`
2 parents e0ceabc + f6a022b commit 1e2aa70

File tree

3 files changed

+353
-21
lines changed

3 files changed

+353
-21
lines changed

numpy/lib/twodim_base.pyi

Lines changed: 244 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,255 @@
1-
from typing import List, Optional, Any
1+
from typing import (
2+
Any,
3+
Callable,
4+
List,
5+
Sequence,
6+
overload,
7+
Tuple,
8+
Type,
9+
TypeVar,
10+
Union,
11+
)
212

3-
from numpy import ndarray, _OrderCF
4-
from numpy.typing import ArrayLike, DTypeLike
13+
from numpy import (
14+
ndarray,
15+
dtype,
16+
generic,
17+
number,
18+
bool_,
19+
timedelta64,
20+
datetime64,
21+
int_,
22+
intp,
23+
float64,
24+
signedinteger,
25+
floating,
26+
complexfloating,
27+
object_,
28+
_OrderCF,
29+
)
30+
31+
from numpy.typing import (
32+
DTypeLike,
33+
_SupportsDType,
34+
ArrayLike,
35+
NDArray,
36+
_NestedSequence,
37+
_SupportsArray,
38+
_ArrayLikeInt_co,
39+
_ArrayLikeFloat_co,
40+
_ArrayLikeComplex_co,
41+
_ArrayLikeObject_co,
42+
)
43+
44+
_T = TypeVar("_T")
45+
_SCT = TypeVar("_SCT", bound=generic)
46+
47+
# The returned arrays dtype must be compatible with `np.equal`
48+
_MaskFunc = Callable[
49+
[NDArray[int_], _T],
50+
NDArray[Union[number[Any], bool_, timedelta64, datetime64, object_]],
51+
]
52+
53+
_DTypeLike = Union[
54+
Type[_SCT],
55+
dtype[_SCT],
56+
_SupportsDType[dtype[_SCT]],
57+
]
58+
_ArrayLike = _NestedSequence[_SupportsArray[dtype[_SCT]]]
559

660
__all__: List[str]
761

8-
def fliplr(m): ...
9-
def flipud(m): ...
62+
@overload
63+
def fliplr(m: _ArrayLike[_SCT]) -> NDArray[_SCT]: ...
64+
@overload
65+
def fliplr(m: ArrayLike) -> NDArray[Any]: ...
66+
67+
@overload
68+
def flipud(m: _ArrayLike[_SCT]) -> NDArray[_SCT]: ...
69+
@overload
70+
def flipud(m: ArrayLike) -> NDArray[Any]: ...
1071

72+
@overload
1173
def eye(
1274
N: int,
13-
M: Optional[int] = ...,
75+
M: None | int = ...,
76+
k: int = ...,
77+
dtype: None = ...,
78+
order: _OrderCF = ...,
79+
*,
80+
like: None | ArrayLike = ...,
81+
) -> NDArray[float64]: ...
82+
@overload
83+
def eye(
84+
N: int,
85+
M: None | int = ...,
86+
k: int = ...,
87+
dtype: _DTypeLike[_SCT] = ...,
88+
order: _OrderCF = ...,
89+
*,
90+
like: None | ArrayLike = ...,
91+
) -> NDArray[_SCT]: ...
92+
@overload
93+
def eye(
94+
N: int,
95+
M: None | int = ...,
1496
k: int = ...,
1597
dtype: DTypeLike = ...,
1698
order: _OrderCF = ...,
1799
*,
18-
like: Optional[ArrayLike] = ...
19-
) -> ndarray[Any, Any]: ...
20-
21-
def diag(v, k=...): ...
22-
def diagflat(v, k=...): ...
23-
def tri(N, M=..., k=..., dtype = ..., *, like=...): ...
24-
def tril(m, k=...): ...
25-
def triu(m, k=...): ...
26-
def vander(x, N=..., increasing=...): ...
27-
def histogram2d(x, y, bins=..., range=..., normed=..., weights=..., density=...): ...
28-
def mask_indices(n, mask_func, k=...): ...
29-
def tril_indices(n, k=..., m=...): ...
30-
def tril_indices_from(arr, k=...): ...
31-
def triu_indices(n, k=..., m=...): ...
32-
def triu_indices_from(arr, k=...): ...
100+
like: None | ArrayLike = ...,
101+
) -> NDArray[Any]: ...
102+
103+
@overload
104+
def diag(v: _ArrayLike[_SCT], k: int = ...) -> NDArray[_SCT]: ...
105+
@overload
106+
def diag(v: ArrayLike, k: int = ...) -> NDArray[Any]: ...
107+
108+
@overload
109+
def diagflat(v: _ArrayLike[_SCT], k: int = ...) -> NDArray[_SCT]: ...
110+
@overload
111+
def diagflat(v: ArrayLike, k: int = ...) -> NDArray[Any]: ...
112+
113+
@overload
114+
def tri(
115+
N: int,
116+
M: None | int = ...,
117+
k: int = ...,
118+
dtype: None = ...,
119+
*,
120+
like: None | ArrayLike = ...
121+
) -> NDArray[float64]: ...
122+
@overload
123+
def tri(
124+
N: int,
125+
M: None | int = ...,
126+
k: int = ...,
127+
dtype: _DTypeLike[_SCT] = ...,
128+
*,
129+
like: None | ArrayLike = ...
130+
) -> NDArray[_SCT]: ...
131+
@overload
132+
def tri(
133+
N: int,
134+
M: None | int = ...,
135+
k: int = ...,
136+
dtype: DTypeLike = ...,
137+
*,
138+
like: None | ArrayLike = ...
139+
) -> NDArray[Any]: ...
140+
141+
@overload
142+
def tril(v: _ArrayLike[_SCT], k: int = ...) -> NDArray[_SCT]: ...
143+
@overload
144+
def tril(v: ArrayLike, k: int = ...) -> NDArray[Any]: ...
145+
146+
@overload
147+
def triu(v: _ArrayLike[_SCT], k: int = ...) -> NDArray[_SCT]: ...
148+
@overload
149+
def triu(v: ArrayLike, k: int = ...) -> NDArray[Any]: ...
150+
151+
@overload
152+
def vander( # type: ignore[misc]
153+
x: _ArrayLikeInt_co,
154+
N: None | int = ...,
155+
increasing: bool = ...,
156+
) -> NDArray[signedinteger[Any]]: ...
157+
@overload
158+
def vander( # type: ignore[misc]
159+
x: _ArrayLikeFloat_co,
160+
N: None | int = ...,
161+
increasing: bool = ...,
162+
) -> NDArray[floating[Any]]: ...
163+
@overload
164+
def vander(
165+
x: _ArrayLikeComplex_co,
166+
N: None | int = ...,
167+
increasing: bool = ...,
168+
) -> NDArray[complexfloating[Any, Any]]: ...
169+
@overload
170+
def vander(
171+
x: _ArrayLikeObject_co,
172+
N: None | int = ...,
173+
increasing: bool = ...,
174+
) -> NDArray[object_]: ...
175+
176+
@overload
177+
def histogram2d( # type: ignore[misc]
178+
x: _ArrayLikeFloat_co,
179+
y: _ArrayLikeFloat_co,
180+
bins: int | Sequence[int] = ...,
181+
range: None | _ArrayLikeFloat_co = ...,
182+
normed: None | bool = ...,
183+
weights: None | _ArrayLikeFloat_co = ...,
184+
density: None | bool = ...,
185+
) -> Tuple[
186+
NDArray[float64],
187+
NDArray[floating[Any]],
188+
NDArray[floating[Any]],
189+
]: ...
190+
@overload
191+
def histogram2d(
192+
x: _ArrayLikeComplex_co,
193+
y: _ArrayLikeComplex_co,
194+
bins: int | Sequence[int] = ...,
195+
range: None | _ArrayLikeFloat_co = ...,
196+
normed: None | bool = ...,
197+
weights: None | _ArrayLikeFloat_co = ...,
198+
density: None | bool = ...,
199+
) -> Tuple[
200+
NDArray[float64],
201+
NDArray[complexfloating[Any, Any]],
202+
NDArray[complexfloating[Any, Any]],
203+
]: ...
204+
@overload # TODO: Sort out `bins`
205+
def histogram2d(
206+
x: _ArrayLikeComplex_co,
207+
y: _ArrayLikeComplex_co,
208+
bins: Sequence[_ArrayLikeInt_co],
209+
range: None | _ArrayLikeFloat_co = ...,
210+
normed: None | bool = ...,
211+
weights: None | _ArrayLikeFloat_co = ...,
212+
density: None | bool = ...,
213+
) -> Tuple[
214+
NDArray[float64],
215+
NDArray[Any],
216+
NDArray[Any],
217+
]: ...
218+
219+
# NOTE: we're assuming/demanding here the `mask_func` returns
220+
# an ndarray of shape `(n, n)`; otherwise there is the possibility
221+
# of the output tuple having more or less than 2 elements
222+
@overload
223+
def mask_indices(
224+
n: int,
225+
mask_func: _MaskFunc[int],
226+
k: int = ...,
227+
) -> Tuple[NDArray[intp], NDArray[intp]]: ...
228+
@overload
229+
def mask_indices(
230+
n: int,
231+
mask_func: _MaskFunc[_T],
232+
k: _T,
233+
) -> Tuple[NDArray[intp], NDArray[intp]]: ...
234+
235+
def tril_indices(
236+
n: int,
237+
k: int = ...,
238+
m: None | int = ...,
239+
) -> Tuple[NDArray[int_], NDArray[int_]]: ...
240+
241+
def tril_indices_from(
242+
arr: NDArray[Any],
243+
k: int = ...,
244+
) -> Tuple[NDArray[int_], NDArray[int_]]: ...
245+
246+
def triu_indices(
247+
n: int,
248+
k: int = ...,
249+
m: None | int = ...,
250+
) -> Tuple[NDArray[int_], NDArray[int_]]: ...
251+
252+
def triu_indices_from(
253+
arr: NDArray[Any],
254+
k: int = ...,
255+
) -> Tuple[NDArray[int_], NDArray[int_]]: ...
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from typing import Any, List, TypeVar
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
7+
def func1(ar: npt.NDArray[Any], a: int) -> npt.NDArray[np.str_]:
8+
pass
9+
10+
11+
def func2(ar: npt.NDArray[Any], a: float) -> float:
12+
pass
13+
14+
15+
AR_b: npt.NDArray[np.bool_]
16+
AR_m: npt.NDArray[np.timedelta64]
17+
18+
AR_LIKE_b: List[bool]
19+
20+
np.eye(10, M=20.0) # E: No overload variant
21+
np.eye(10, k=2.5, dtype=int) # E: No overload variant
22+
23+
np.diag(AR_b, k=0.5) # E: No overload variant
24+
np.diagflat(AR_b, k=0.5) # E: No overload variant
25+
26+
np.tri(10, M=20.0) # E: No overload variant
27+
np.tri(10, k=2.5, dtype=int) # E: No overload variant
28+
29+
np.tril(AR_b, k=0.5) # E: No overload variant
30+
np.triu(AR_b, k=0.5) # E: No overload variant
31+
32+
np.vander(AR_m) # E: incompatible type
33+
34+
np.histogram2d(AR_m) # E: No overload variant
35+
36+
np.mask_indices(10, func1) # E: incompatible type
37+
np.mask_indices(10, func2, 10.5) # E: incompatible type
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any, List, TypeVar
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
_SCT = TypeVar("_SCT", bound=np.generic)
7+
8+
9+
def func1(ar: npt.NDArray[_SCT], a: int) -> npt.NDArray[_SCT]:
10+
pass
11+
12+
13+
def func2(ar: npt.NDArray[np.number[Any]], a: str) -> npt.NDArray[np.float64]:
14+
pass
15+
16+
17+
AR_b: npt.NDArray[np.bool_]
18+
AR_u: npt.NDArray[np.uint64]
19+
AR_i: npt.NDArray[np.int64]
20+
AR_f: npt.NDArray[np.float64]
21+
AR_c: npt.NDArray[np.complex128]
22+
AR_O: npt.NDArray[np.object_]
23+
24+
AR_LIKE_b: List[bool]
25+
26+
reveal_type(np.fliplr(AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
27+
reveal_type(np.fliplr(AR_LIKE_b)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
28+
29+
reveal_type(np.flipud(AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
30+
reveal_type(np.flipud(AR_LIKE_b)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
31+
32+
reveal_type(np.eye(10)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
33+
reveal_type(np.eye(10, M=20, dtype=np.int64)) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
34+
reveal_type(np.eye(10, k=2, dtype=int)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
35+
36+
reveal_type(np.diag(AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
37+
reveal_type(np.diag(AR_LIKE_b, k=0)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
38+
39+
reveal_type(np.diagflat(AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
40+
reveal_type(np.diagflat(AR_LIKE_b, k=0)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
41+
42+
reveal_type(np.tri(10)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
43+
reveal_type(np.tri(10, M=20, dtype=np.int64)) # E: numpy.ndarray[Any, numpy.dtype[{int64}]]
44+
reveal_type(np.tri(10, k=2, dtype=int)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
45+
46+
reveal_type(np.tril(AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
47+
reveal_type(np.tril(AR_LIKE_b, k=0)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
48+
49+
reveal_type(np.triu(AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
50+
reveal_type(np.triu(AR_LIKE_b, k=0)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
51+
52+
reveal_type(np.vander(AR_b)) # E: numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
53+
reveal_type(np.vander(AR_u)) # E: numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
54+
reveal_type(np.vander(AR_i, N=2)) # E: numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
55+
reveal_type(np.vander(AR_f, increasing=True)) # E: numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]]
56+
reveal_type(np.vander(AR_c)) # E: numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]
57+
reveal_type(np.vander(AR_O)) # E: numpy.ndarray[Any, numpy.dtype[numpy.object_]]
58+
59+
reveal_type(np.histogram2d(AR_i, AR_b)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{float64}]], numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]], numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]]]
60+
reveal_type(np.histogram2d(AR_f, AR_f)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{float64}]], numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]], numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]]]
61+
reveal_type(np.histogram2d(AR_f, AR_c, weights=AR_LIKE_b)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{float64}]], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]]
62+
63+
reveal_type(np.mask_indices(10, func1)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{intp}]], numpy.ndarray[Any, numpy.dtype[{intp}]]]
64+
reveal_type(np.mask_indices(8, func2, "0")) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{intp}]], numpy.ndarray[Any, numpy.dtype[{intp}]]]
65+
66+
reveal_type(np.tril_indices(10)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{int_}]], numpy.ndarray[Any, numpy.dtype[{int_}]]]
67+
68+
reveal_type(np.tril_indices_from(AR_b)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{int_}]], numpy.ndarray[Any, numpy.dtype[{int_}]]]
69+
70+
reveal_type(np.triu_indices(10)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{int_}]], numpy.ndarray[Any, numpy.dtype[{int_}]]]
71+
72+
reveal_type(np.triu_indices_from(AR_b)) # E: Tuple[numpy.ndarray[Any, numpy.dtype[{int_}]], numpy.ndarray[Any, numpy.dtype[{int_}]]]

0 commit comments

Comments
 (0)
0