8000 TYP: np.argmin changes · numpy/numpy@e536b84 · GitHub
[go: up one dir, main page]

Skip to content

Commit e536b84

Browse files
committed
TYP: np.argmin changes
1 parent 00f2733 commit e536b84

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

numpy/_core/fromnumeric.pyi

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ _NumberOrObjectT = TypeVar("_NumberOrObjectT", bound=np.number | np.object_)
111111
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
112112
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
113113
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], covariant=True)
114+
_BoolOrIntArrayT = TypeVar("_BoolOrIntArrayT", bound=NDArray[np.integer | np.bool])
114115

115116
@type_check_only
116117
class _SupportsShape(Protocol[_ShapeT_co]):
@@ -418,18 +419,18 @@ def argmax(
418419
def argmax(
419420
a: ArrayLike,
420421
axis: SupportsIndex | None,
421-
out: _ArrayT,
422+
out: _BoolOrIntArrayT,
422423
*,
423424
keepdims: bool = ...,
424-
) -> _ArrayT: ...
425+
) -> _BoolOrIntArrayT: ...
425426
@overload
426427
def argmax(
427428
a: ArrayLike,
428429
axis: SupportsIndex | None = ...,
429430
*,
430-
out: _ArrayT,
431+
out: _BoolOrIntArrayT,
431432
keepdims: bool = ...,
432-
) -> _ArrayT: ...
433+
) -> _BoolOrIntArrayT: ...
433434

434435
@overload
435436
def argmin(
@@ -451,18 +452,18 @@ def argmin(
451452
def argmin(
452453
a: ArrayLike,
453454
axis: SupportsIndex | None,
454-
out: _ArrayT,
455+
out: _BoolOrIntArrayT,
455456
*,
456457
keepdims: bool = ...,
457-
) -> _ArrayT: ...
458+
) -> _BoolOrIntArrayT: ...
458459
@overload
459460
def argmin(
460461
a: ArrayLike,
461462
axis: SupportsIndex | None = ...,
462463
*,
463-
out: _ArrayT,
464+
out: _BoolOrIntArrayT,
464465
keepdims: bool = ...,
465-
) -> _ArrayT: ...
466+
) -> _BoolOrIntArrayT: ...
466467

467468
@overload
468469
def searchsorted(

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ f4: np.float32
2525
i8: np.int64
2626
f: float
2727

28+
# integer‑dtype subclass for argmin/argmax
29+
class NDArrayIntSubclass(npt.NDArray[np.intp]): ...
30+
AR_sub_i: NDArrayIntSubclass
31+
2832
assert_type(np.take(b, 0), np.bool)
2933
assert_type(np.take(f4, 0), np.float32)
3034
assert_type(np.take(f, 0), Any)
@@ -89,13 +93,19 @@ assert_type(np.argmax(AR_b), np.intp)
8993
assert_type(np.argmax(AR_f4), np.intp)
9094
assert_type(np.argmax(AR_b, axis=0), Any)
9195
assert_type(np.argmax(AR_f4, axis=0), Any)
92-
assert_type(np.argmax(AR_f4, out=AR_subclass), NDArraySubclass)
96+
assert_type(np.argmax(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
9397

9498
assert_type(np.argmin(AR_b), np.intp)
9599
assert_type(np.argmin(AR_f4), np.intp)
96100
assert_type(np.argmin(AR_b, axis=0), Any)
97101
assert_type(np.argmin(AR_f4, axis=0), Any)
98-
assert_type(np.argmin(AR_f4, out=AR_subclass), NDArraySubclass)
102+
assert_type(np.argmin(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
103+
104+
# assert_type(np.argmax(AR_f4, out=AR_sub_i), NDArraySubclass2)
105+
# assert_type(np.argmin(AR_f4, out=AR_sub_i), NDArraySubclass2)
106+
107+
# assert_type(np.argmin(AR_f4, out=AR_i8), npt.NDArray[np.intp])
108+
# assert_type(np.argmax(AR_f4, out=AR_i8), npt.NDArray[np.intp])
99109

100110
assert_type(np.searchsorted(AR_b[0], 0), np.intp)
101111
assert_type(np.searchsorted(AR_f4[0], 0), np.intp)

0 commit comments

Comments
 (0)
0