8000 TYP: fix stubtest errors in ``numpy.lib._index_tricks_impl`` · numpy/numpy@b6de917 · GitHub
[go: up one dir, main page]

Skip to content

Commit b6de917

Browse files
jorenhamcharris
authored andcommitted
TYP: fix stubtest errors in numpy.lib._index_tricks_impl
Ported from numpy/numtype#235 --- - move `ndenumerate` and `ndindex` definitions to `lib._index_tricks_impl` - add deprecated `ndenumerate.ndincr` property - removed non-existent `ndenumerate.iter` property - remove incorrect "pass" and "reveal" type-tests for `ndenumerate.iter` - fix incorrect `ndenumerate` constructor fallback return type - fix `AxisConcatenator.makemat` signature
1 parent 4a20c51 commit b6de917

File tree

4 files changed

+149
-174
lines changed

4 files changed

+149
-174
lines changed

numpy/__init__.pyi

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ from numpy.lib._histograms_impl import (
515515
)
516516

517517
from numpy.lib._index_tricks_impl import (
518+
ndenumerate,
519+
ndindex,
518520
ravel_multi_index,
519521
unravel_index,
520522
mgrid,
@@ -4964,50 +4966,6 @@ class errstate:
49644966
) -> None: ...
49654967
def __call__(self, func: _CallableT) -> _CallableT: ...
49664968

4967-
class ndenumerate(Generic[_SCT_co]):
4968-
@property
4969-
def iter(self) -> flatiter[NDArray[_SCT_co]]: ...
4970-
4971-
@overload
4972-
def __new__(
4973-
cls, arr: _FiniteNestedSequence[_SupportsArray[dtype[_SCT]]],
4974-
) -> ndenumerate[_SCT]: ...
4975-
@overload
4976-
def __new__(cls, arr: str | _NestedSequence[str]) -> ndenumerate[str_]: ...
4977-
@overload
4978-
def __new__(cls, arr: bytes | _NestedSequence[bytes]) -> ndenumerate[bytes_]: ...
4979-
@overload
4980-
def __new__(cls, arr: builtins.bool | _NestedSequence[builtins.bool]) -> ndenumerate[np.bool]: ...
4981-
@overload
4982-
def __new__(cls, arr: int | _NestedSequence[int]) -> ndenumerate[int_]: ...
4983-
@overload
4984-
def __new__(cls, arr: float | _NestedSequence[float]) -> ndenumerate[float64]: ...
4985-
@overload
4986-
def __new__(cls, arr: complex | _NestedSequence[complex]) -> ndenumerate[complex128]: ...
4987-
@overload
4988-
def __new__(cls, arr: object) -> ndenumerate[object_]: ...
4989-
4990-
# The first overload is a (semi-)workaround for a mypy bug (tested with v1.10 and v1.11)
4991-
@overload
4992-
def __next__(
4993-
self: ndenumerate[np.bool | datetime64 | timedelta64 | number[Any] | flexible],
4994-
/,
4995-
) -> tuple[_Shape, _SCT_co]: ...
4996-
@overload
4997-
def __next__(self: ndenumerate[object_], /) -> tuple[_Shape, Any]: ...
4998-
@overload
4999-
def __next__(self, /) -> tuple[_Shape, _SCT_co]: ...
5000-
5001-
def __iter__(self) -> Self: ...
5002-
5003-
class ndindex:
5004-
@overload
5005-
def __init__(self, shape: tuple[SupportsIndex, ...], /) -> None: ...
5006-
@overload
5007-
def __init__(self, *shape: SupportsIndex) -> None: ...
5008-
def __iter__(self) -> Self: ...
5009-
def __next__(self) -> _Shape: ...
5010-
50114969
# TODO: The type of each `__next__` and `iters` return-type depends
50124970
# on the length and dtype of `args`; we can't describe this behavior yet
50134971
# as we lack variadics (PEP 646).

numpy/lib/_index_tricks_impl.pyi

Lines changed: 144 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,23 @@
11
from collections.abc import Sequence
2-
from typing import (
3-
Any,
4-
TypeVar,
5-
Generic,
6-
overload,
7-
Literal,
8-
SupportsIndex,
9-
)
2+
from typing import Any, ClassVar, Final, Generic, SupportsIndex, final, overload
3+
from typing import Literal as L
4+
5+
from _typeshed import Incomplete
6+
from typing_extensions import Self, TypeVar, deprecated
107

118
import numpy as np
12-
from numpy import (
13-
# Circumvent a naming conflict with `AxisConcatenator.matrix`
14-
matrix as _Matrix,
15-
ndenumerate,
16-
ndindex,
17-
ndarray,
18-
dtype,
19-
str_,
20-
bytes_,
21-
int_,
22-
float64,
23-
complex128,
24-
)
9+
from numpy._core.multiarray import ravel_multi_index, unravel_index
2510
from numpy._typing import (
26-
# Arrays
2711
ArrayLike,
28-
_NestedSequence,
29-
_FiniteNestedSequence,
3012
NDArray,
31-
32-
# DTypes
33-
DTypeLike,
34-
_SupportsDType,
35-
36-
# Shapes
13+
_FiniteNestedSequence,
14+
_NestedSequence,
3715
_Shape,
16+
_SupportsArray,
17+
_SupportsDType,
3818
)
3919

40-
from numpy._core.multiarray import unravel_index, ravel_multi_index
41-
42-
__all__ = [
20+
__all__ = [ # noqa: RUF022
4321
"ravel_multi_index",
4422
"unravel_index",
4523
"mgrid",
@@ -56,114 +34,163 @@ __all__ = [
5634
"diag_indices_from",
5735
]
5836

37+
###
38+
5939
_T = TypeVar("_T")
60-
_DType = TypeVar("_DType", bound=dtype[Any])
61-
_BoolType = TypeVar("_BoolType", Literal[True], Literal[False])
62-
_TupType = TypeVar("_TupType", bound=tuple[Any, ...])
63-
_ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
40+
_TupleT = TypeVar("_TupleT", bound=tuple[Any, ...])
41+
_ArrayT = TypeVar("_ArrayT", bound=NDArray[Any])
42+
_DTypeT = TypeVar("_DTypeT", bound=np.dtype[Any])
43+
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
44+
_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, covariant=True)
45+
_BoolT_co = TypeVar("_BoolT_co", bound=bool, default=bool, covariant=True)
6446

65-
@overload
66-
def ix_(*args: _FiniteNestedSequence[_SupportsDType[_DType]]) -> tuple[ndarray[_Shape, _DType], ...]: ...
67-
@overload
68-
def ix_(*args: str | _NestedSequence[str]) -> tuple[NDArray[str_], ...]: ...
69-
@overload
70-
def ix_(*args: bytes | _NestedSequence[bytes]) -> tuple[NDArray[bytes_], ...]: ...
71-
@overload
72-
def ix_(*args: bool | _NestedSequence[bool]) -> tuple[NDArray[np.bool], ...]: ...
73-
@overload
74-
def ix_(*args: int | _NestedSequence[int]) -> tuple[NDArray[int_], ...]: ...
75-
@overload
76-
def ix_(*args: float | _NestedSequence[float]) -> tuple[NDArray[float64], ...]: ...
77-
@overload
78-
def ix_(*args: complex | _NestedSequence[complex]) -> tuple[NDArray[complex128], ...]: ...
47+
_AxisT_co = TypeVar("_AxisT_co", bound=int, default=L[0], covariant=True)
48+
_MatrixT_co = TypeVar("_MatrixT_co", bound=bool, default=L[False], covariant=True)
49+
_NDMinT_co = TypeVar("_NDMinT_co", bound=int, default=L[1], covariant=True)
50+
_Trans1DT_co = TypeVar("_Trans1DT_co", bound=int, default=L[-1], covariant=True)
51+
52+
###
7953

80-
class nd_grid(Generic[_BoolType]):
81-
sparse: _BoolType
82-
def __init__(self, sparse: _BoolType = ...) -> None: ...
54+
class ndenumerate(Generic[_ScalarT_co]):
55+
@overload
56+
def __new__(cls, arr: _FiniteNestedSequence[_SupportsArray[np.dtype[_ScalarT]]]) -> ndenumerate[_ScalarT]: ...
57+
@overload
58+
def __new__(cls, arr: str | _NestedSequence[str]) -> ndenumerate[np.str_]: ...
8359
@overload
84-
def __getitem__(
85-
self: nd_grid[Literal[False]],
86-
key: slice | Sequence[slice],
87-
) -> NDArray[Any]: ...
60+
def __new__(cls, arr: bytes | _NestedSequence[bytes]) -> ndenumerate[np.bytes_]: ...
8861
@overload
89-
def __getitem__(
90-
self: nd_grid[Literal[True]],
91-
key: slice | Sequence[slice],
92-
) -> tuple[NDArray[Any], ...]: ...
62+
def __new__(cls, arr: bool | _NestedSequence[bool]) -> ndenumerate[np.bool]: ...
63+
@overload
64+
def __new__(cls, arr: int | _NestedSequence[int]) -> ndenumerate[np.intp]: ...
65+
@overload
66+
def __new__(cls, arr: float | _NestedSequence[float]) -> ndenumerate[np.float64]: ...
67+
@overload
68+
def __new__(cls, arr: complex | _NestedSequence[complex]) -> ndenumerate[np.complex128]: ...
69+
@overload
70+
def __new__(cls, arr: object) -> ndenumerate[Any]: ...
9371

94-
class MGridClass(nd_grid[Literal[False]]):
95-
def __init__(self) -> None: ...
72+
# The first overload is a (semi-)workaround for a mypy bug (tested with v1.10 and v1.11)
73+
@overload
74+
def __next__(
75+
self: ndenumerate[np.bool | np.number | np.flexible | np.datetime64 | np.timedelta64],
76+
/,
77+
) -> tuple[tuple[int, ...], _ScalarT_co]: ...
78+
@overload
79+
def __next__(self: ndenumerate[np.object_], /) -> tuple[tuple[int, ...], Any]: ...
80+
@overload
81+
def __next__(self, /) -> tuple[tuple[int, ...], _ScalarT_co]: ...
82+
83+
#
84+
def __iter__(self) -> Self: ...
85+
86+
class ndindex:
87+
@overload
88+
def __init__(self, shape: tuple[SupportsIndex, ...], /) -> None: ...
89+
@overload
90+
def __init__(self, /, *shape: SupportsIndex) -> None: ...
91+
92+
#
93+
def __iter__(self) -> Self: ...
94+
def __next__(self) -> tuple[int, ...]: ...
9695

97-
mgrid: MGridClass
96+
#
97+
@deprecated("Deprecated since 1.20.0.")
98+
def ndincr(self, /) -> None: ...
9899

99-
class OGridClass(nd_grid[Literal[True]]):
100+
class nd_grid(Generic[_BoolT_co]):
101+
sparse: _BoolT_co
102+
def __init__(self, sparse: _BoolT_co = ...) -> None: ...
103+
@overload
104+
def __getitem__(self: nd_grid[L[False]], key: slice | Sequence[slice]) -> NDArray[Any]: ...
105+
@overload
106+
def __getitem__(self: nd_grid[L[True]], key: slice | Sequence[slice]) -> tuple[NDArray[Any], ...]: ...
107+
108+
@final
109+
class MGridClass(nd_grid[L[False]]):
110+
def __init__(self) -> None: ...
111+
112+
@final
113+
class OGridClass(nd_grid[L[True]]):
100114
def __init__(self) -> None: ...
101115

102-
ogrid: OGridClass
116+
class AxisConcatenator(Generic[_AxisT_co, _MatrixT_co, _NDMinT_co, _Trans1DT_co]):
117+
__slots__ = "axis", "matrix", "ndmin", "trans1d"
118+
119+
makemat: ClassVar[type[np.matrix[tuple[int, int], np.dtype[Any]]]]
103120

104-
class AxisConcatenator:
105-
axis: int
106-
matrix: bool
107-
ndmin: int
108-
trans1d: int
121+
axis: _AxisT_co
122+
matrix: _MatrixT_co
123+
ndmin: _NDMinT_co
124+
trans1d: _Trans1DT_co
125+
126+
#
109127
def __init__(
110128
self,
111-
axis: int = ...,
112-
matrix: bool = ...,
113-
ndmin: int = ...,
114-
trans1d: int = ...,
129+
/,
130+
axis: _AxisT_co = ...,
131+
matrix: _MatrixT_co = ...,
132+
ndmin: _NDMinT_co = ...,
133+
trans1d: _Trans1DT_co = ...,
115134
) -> None: ...
135+
136+
# TODO(jorenham): annotate this
137+
def __getitem__(self, key: Incomplete, /) -> Incomplete: ...
138+
def __len__(self, /) -> L[0]: ...
139+
140+
#
116141
@staticmethod
117142
@overload
118-
def concatenate( # type: ignore[misc]
119-
*a: ArrayLike, axis: SupportsIndex = ..., out: None = ...
120-
) -> NDArray[Any]: ...
143+
def concatenate(*a: ArrayLike, axis: SupportsIndex | None = 0, out: _ArrayT) -> _ArrayT: ...
121144
@staticmethod
122145
@overload
123-
def concatenate(
124-
*a: ArrayLike, axis: SupportsIndex = ..., out: _ArrayType = ...
125-
) -> _ArrayType: ...
126-
@staticmethod
127-
def makemat(
128-
data: ArrayLike, dtype: DTypeLike = ..., copy: bool = ...
129-
) -> _Matrix[Any, Any]: ...
130-
131-
# TODO: Sort out this `__getitem__` method
132-
def __getitem__(self, key: Any) -> Any: ...
133-
134-
class RClass(AxisConcatenator):
135-
axis: Literal[0]
136-
matrix: Literal[False]
137-
ndmin: Literal[1]
138-
trans1d: Literal[-1]
139-
def __init__(self) -> None: ...
146+
def concatenate(*a: ArrayLike, axis: SupportsIndex | None = 0, out: None = None) -> NDArray[Any]: ...
140147

141-
r_: RClass
142-
143-
class CClass(AxisConcatenator):
144-
axis: Literal[-1]
145-
matrix: Literal[False]
146-
ndmin: Literal[2]
147-
trans1d: Literal[0]
148-
def __init__(self) -> None: ...
148+
@final
149+
class RClass(AxisConcatenator[L[0], L[False], L[1], L[-1]]):
150+
def __init__(self, /) -> None: ...
149151

150-
c_: CClass
152+
@final
153+
class CClass(AxisConcatenator[L[-1], L[False], L[2], L[0]]):
154+
def __init__(self, /) -> None: ...
151155

152-
class IndexExpression(Generic[_BoolType]):
153-
maketuple: _BoolType
154-
def __init__(self, maketuple: _BoolType) -> None: ...
156+
class IndexExpression(Generic[_BoolT_co]):
157+
maketuple: _BoolT_co
158+
def __init__(self, maketuple: _BoolT_co) -> None: ...
155159
@overload
156-
def __getitem__(self, item: _TupType) -> _TupType: ... # type: ignore[misc]
160+
def __getitem__(self, item: _TupleT) -> _TupleT: ...
157161
@overload
158-
def __getitem__(self: IndexExpression[Literal[True]], item: _T) -> tuple[_T]: ...
162+
def __getitem__(self: IndexExpression[L[True]], item: _T) -> tuple[_T]: ...
159163
@overload
160-
def __getitem__(self: IndexExpression[Literal[False]], item: _T) -> _T: ...
164+
def __getitem__(self: IndexExpression[L[False]], item: _T) -> _T: ...
165+
166+
@overload
167+
def ix_(*args: _FiniteNestedSequence[_SupportsDType[_DTypeT]]) -> tuple[np.ndarray[_Shape, _DTypeT], ...]: ...
168+
@overload
169+
def ix_(*args: str | _NestedSequence[str]) -> tuple[NDArray[np.str_], ...]: ...
170+
@overload
171+
def ix_(*args: bytes | _NestedSequence[bytes]) -> tuple[NDArray[np.bytes_], ...]: ...
172+
@overload
173+
def ix_(*args: bool | _NestedSequence[bool]) -> tuple[NDArray[np.bool], ...]: ...
174+
@overload
175+
def ix_(*args: int | _NestedSequence[int]) -> tuple[NDArray[np.intp], ...]: ...
176+
@overload
177+
def ix_(*args: float | _NestedSequence[float]) -> tuple[NDArray[np.float64], ...]: ...
178+
@overload
179+
def ix_(*args: complex | _NestedSequence[complex]) -> tuple[NDArray[np.complex128], ...]: ...
180+
181+
#
182+
def fill_diagonal(a: NDArray[Any], val: object, wrap: bool = ...) -> None: ...
183+
184+
#
185+
def diag_indices(n: int, ndim: int = ...) -> tuple[NDArray[np.intp], ...]: ...
186+
def diag_indices_from(arr: ArrayLike) -> tuple[NDArray[np.intp], ...]: ...
161187

162-
index_exp: IndexExpression[Literal[True]]
163-
s_: IndexExpression[Literal[False]]
188+
#
189+
mgrid: Final[MGridClass] = ...
190+
ogrid: Final[OGridClass] = ...
164191

165-
def fill_diagonal(a: NDArray[Any], val: Any, wrap: bool = ...) -> None: ...
166-
def diag_indices(n: int, ndim: int = ...) -> tuple[NDArray[int_], ...]: ...
167-
def diag_indices_from(arr: ArrayLike) -> tuple[NDArray[int_], ...]: ...
192+
r_: Final[RClass] = ...
193+
c_: Final[CClass] = ...
168194

169-
# NOTE: see `numpy/__init__.pyi` for `ndenumerate` and `ndindex`
195+
index_exp: Final[IndexExpression[L[True]]] = ...
196+
s_: Final[IndexExpression[L[False]]] = ...

numpy/typing/tests/data/pass/index_tricks.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313
np.ndenumerate(AR_LIKE_f)
1414
np.ndenumerate(AR_LIKE_U)
1515

16-
np.ndenumerate(AR_i8).iter
17-
np.ndenumerate(AR_LIKE_f).iter
18-
np.ndenumerate(AR_LIKE_U).iter
19-
2016
next(np.ndenumerate(AR_i8))
2117
next(np.ndenumerate(AR_LIKE_f))
2218
next(np.ndenumerate(AR_LIKE_U))

0 commit comments

Comments
 (0)
0