8000 TYP: np.argmin changes · numpy/numpy@6842d13 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6842d13

Browse files
committed
TYP: np.argmin changes
1 parent 00f2733 commit 6842d13

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

numpy/__init__.pyi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ _CharacterItemT_co = TypeVar("_CharacterItemT_co", bound=_CharLike_co, default=_
829829
_TD64ItemT_co = TypeVar("_TD64ItemT_co", bound=dt.timedelta | int | None, default=dt.timedelta | int | None, covariant=True)
830830
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
831831
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)
832+
_BoolOrIntArrayT = TypeVar("_BoolOrIntArrayT", bound=NDArray[np.integer | np.bool])
832833

833834
### Type Aliases (for internal use only)
834835

@@ -1707,18 +1708,18 @@ class _ArrayOrScalarCommon:
17071708
@overload # axis=index, out=None (default)
17081709
def argmax(self, /, axis: SupportsIndex, out: None = None, *, keepdims: builtins.bool = False) -> Any: ...
17091710
@overload # axis=index, out=ndarray
1710-
def argmax(self, /, axis: SupportsIndex | None, out: _ArrayT, *, keepdims: builtins.bool = False) -> _ArrayT: ...
1711+
def argmax(self, /, axis: SupportsIndex | None, out: _BoolOrIntArrayT, *, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17111712
@overload
1712-
def argmax(self, /, axis: SupportsIndex | None = None, *, out: _ArrayT, keepdims: builtins.bool = False) -> _ArrayT: ...
1713+
def argmax(self, /, axis: SupportsIndex | None = None, *, out: _BoolOrIntArrayT, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17131714

17141715
@overload # axis=None (default), out=None (default), keepdims=False (default)
17151716
def argmin(self, /, axis: None = None, out: None = None, *, keepdims: L[False] = False) -> intp: ...
17161717
@overload # axis=index, out=None (default)
17171718
def argmin(self, /, axis: SupportsIndex, out: None = None, *, keepdims: builtins.bool = False) -> Any: ...
17181719
@overload # axis=index, out=ndarray
1719-
def argmin(self, /, axis: SupportsIndex | None, out: _ArrayT, *, keepdims: builtins.bool = False) -> _ArrayT: ...
1720+
def argmin(self, /, axis: SupportsIndex | None, out: _BoolOrIntArrayT, *, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17201721
@overload
1721-
def argmin(self, /, axis: SupportsIndex | None = None, *, out: _ArrayT, keepdims: builtins.bool = False) -> _ArrayT: ...
1722+
def argmin(self, /, axis: SupportsIndex | None = None, *, out: _BoolOrIntArrayT, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17221723

17231724
@overload # out=None (default)
17241725
def round(self, /, decimals: SupportsIndex = 0, out: None = None) -> Self: ...
@@ -5363,14 +5364,14 @@ class matrix(ndarray[_2DShapeT_co, _DTypeT_co]):
53635364
@overload
53645365
def argmax(self, axis: _ShapeLike, out: None = ...) -> matrix[_2D, dtype[intp]]: ...
53655366
@overload
5366-
def argmax(self, axis: _ShapeLike | None = ..., out: _ArrayT = ...) -> _ArrayT: ...
5367+
def argmax(self, axis: _ShapeLike | None = ..., out: _BoolOrIntArrayT = ...) -> _BoolOrIntArrayT: ...
53675368

53685369
@overload
53695370
def argmin(self: NDArray[_ScalarT], axis: None = ..., out: None = ...) -> intp: ...
53705371
@overload
53715372
def argmin(self, axis: _ShapeLike, out: None = ...) -> matrix[_2D, dtype[intp]]: ...
53725373
@overload
5373-
def argmin(self, axis: _ShapeLike | None = ..., out: _ArrayT = ...) -> _ArrayT: ...
5374+
def argmin(self, axis: _ShapeLike | None = ..., out: _BoolOrIntArrayT = ...) -> _BoolOrIntArrayT: ...
53745375

53755376
@overload
53765377
def ptp(self: NDArray[_ScalarT], axis: None = ..., out: None = ...) -> _ScalarT: ...

numpy/_core/fromnumeric.pyi

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ _NumberOrObjectT = TypeVar("_NumberOrObjectT", bound=np.number | np.object_)
111111
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
112112
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
113113
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], covariant=True)
114+
_BoolOrIntArrayT = TypeVar("_BoolOrIntArrayT", bound=NDArray[np.integer | np.bool])
114115

115116
@type_check_only
116117
class _SupportsShape(Protocol[_ShapeT_co]):
@@ -418,18 +419,18 @@ def argmax(
418419
def argmax(
419420
a: ArrayLike,
420421
axis: SupportsIndex | None,
421-
out: _ArrayT,
422+
out: _BoolOrIntArrayT,
422423
*,
423424
keepdims: bool = ...,
424-
) -> _ArrayT: ...
425+
) -> _BoolOrIntArrayT: ...
425426
@overload
426427
def argmax(
427428
a: ArrayLike,
428429
axis: SupportsIndex | None = ...,
429430
*,
430-
out: _ArrayT,
431+
out: _BoolOrIntArrayT,
431432
keepdims: bool = ...,
432-
) -> _ArrayT: ...
433+
) -> _BoolOrIntArrayT: ...
433434

434435
@overload
435436
def argmin(
@@ -451,18 +452,18 @@ def argmin(
451< 628C /code>452
def argmin(
452453
a: ArrayLike,
453454
axis: SupportsIndex | None,
454-
out: _ArrayT,
455+
out: _BoolOrIntArrayT,
455456
*,
456457
keepdims: bool = ...,
457-
) -> _ArrayT: ...
458+
) -> _BoolOrIntArrayT: ...
458459
@overload
459460
def argmin(
460461
a: ArrayLike,
461462
axis: SupportsIndex | None = ...,
462463
*,
463-
out: _ArrayT,
464+
out: _BoolOrIntArrayT,
464465
keepdims: bool = ...,
465-
) -> _ArrayT: ...
466+
) -> _BoolOrIntArrayT: ...
466467

467468
@overload
468469
def searchsorted(

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ f4: np.float32
2525
i8: np.int64
2626
f: float
2727

28+
# integer‑dtype subclass for argmin/argmax
29+
class NDArrayIntSubclass(npt.NDArray[np.intp]): ...
30+
AR_sub_i: NDArrayIntSubclass
31+
2832
assert_type(np.take(b, 0), np.bool)
2933
assert_type(np.take(f4, 0), np.float32)
3034
assert_type(np.take(f, 0), Any)
@@ -89,13 +93,13 @@ assert_type(np.argmax(AR_b), np.intp)
8993
assert_type(np.argmax(AR_f4), np.intp)
9094
assert_type(np.argmax(AR_b, axis=0), Any)
9195
assert_type(np.argmax(AR_f4, axis=0), Any)
92-
assert_type(np.argmax(AR_f4, out=AR_subclass), NDArraySubclass)
96+
assert_type(np.argmax(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
9397

9498
assert_type(np.argmin(AR_b), np.intp)
9599
assert_type(np.argmin(AR_f4), np.intp)
96100
assert_type(np.argmin(AR_b, axis=0), Any)
97101
assert_type(np.argmin(AR_f4, axis=0), Any)
98-
assert_type(np.argmin(AR_f4, out=AR_subclass), NDArraySubclass)
102+
assert_type(np.argmin(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
99103

100104
assert_type(np.searchsorted(AR_b[0], 0), np.intp)
101105
assert_type(np.searchsorted(AR_f4[0], 0), np.intp)

0 commit comments

Comments
 (0)
0