8000 Merge pull request #28906 from lvllvl/issue-28641 · numpy/numpy@7beea21 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7beea21

Browse files
authored
Merge pull request #28906 from lvllvl/issue-28641
2 parents 5fe514b + cd6b1d3 commit 7beea21

File tree

7 files changed

+36
-25
lines changed

7 files changed

+36
-25
lines changed

numpy/__init__.pyi

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,7 @@ _CharacterItemT_co = TypeVar("_CharacterItemT_co", bound=_CharLike_co, default=_
826826
_TD64ItemT_co = TypeVar("_TD64ItemT_co", bound=dt.timedelta | int | None, default=dt.timedelta | int | None, covariant=True)
827827
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
828828
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)
829+
_BoolOrIntArrayT = TypeVar("_BoolOrIntArrayT", bound=NDArray[integer | np.bool])
829830

830831
### Type Aliases (for internal use only)
831832

@@ -1704,18 +1705,18 @@ class _ArrayOrScalarCommon:
17041705
@overload # axis=index, out=None (default)
17051706
def argmax(self, /, axis: SupportsIndex, out: None = None, *, keepdims: builtins.bool = False) -> Any: ...
17061707
@overload # axis=index, out=ndarray
1707-
def argmax(self, /, axis: SupportsIndex | None, out: _ArrayT, *, keepdims: builtins.bool = False) -> _ArrayT: ...
1708+
def argmax(self, /, axis: SupportsIndex | None, out: _BoolOrIntArrayT, *, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17081709
@overload
1709-
def argmax(self, /, axis: SupportsIndex | None = None, *, out: _ArrayT, keepdims: builtins.bool = False) -> _ArrayT: ...
1710+
def argmax(self, /, axis: SupportsIndex | None = None, *, out: _BoolOrIntArrayT, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17101711

17111712
@overload # axis=None (default), out=None (default), keepdims=False (default)
17121713
def argmin(self, /, axis: None = None, out: None = None, *, keepdims: L[False] = False) -> intp: ...
17131714
@overload # axis=index, out=None (default)
17141715
def argmin(self, /, axis: SupportsIndex, out: None = None, *, keepdims: builtins.bool = False) -> Any: ...
17151716
@overload # axis=index, out=ndarray
1716-
def argmin(self, /, axis: SupportsIndex | None, out: _ArrayT, *, keepdims: builtins.bool = False) -> _ArrayT: ...
1717+
def argmin(self, /, axis: SupportsIndex | None, out: _BoolOrIntArrayT, *, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17171718
@overload
1718-
def argmin(self, /, axis: SupportsIndex | None = None, *, out: _ArrayT, keepdims: builtins.bool = False) -> _ArrayT: ...
1719+
def argmin(self, /, axis: SupportsIndex | None = None, *, out: _BoolOrIntArrayT, keepdims: builtins.bool = False) -> _BoolOrIntArrayT: ...
17191720

17201721
@overload # out=None (default)
17211722
def round(self, /, decimals: SupportsIndex = 0, out: None = None) -> Self: ...
@@ -5364,19 +5365,19 @@ class matrix(ndarray[_2DShapeT_co, _DTypeT_co]):
53645365
@overload
53655366
def argmax(self, axis: _ShapeLike, out: None = None) -> matrix[_2D, dtype[intp]]: ...
53665367
@overload
5367-
def argmax(self, axis: _ShapeLike | None, out: _ArrayT) -> _ArrayT: ...
5368+
def argmax(self, axis: _ShapeLike | None, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ...
53685369
@overload
5369-
def argmax(self, axis: _ShapeLike | None = None, *, out: _ArrayT) -> _ArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
5370+
def argmax(self, axis: _ShapeLike | None = None, *, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
53705371

53715372
# keep in sync with `argmax`
53725373
@overload # type: ignore[override]
53735374
def argmin(self: NDArray[_ScalarT], axis: None = None, out: None = None) -> intp: ...
53745375
@overload
53755376
def argmin(self, axis: _ShapeLike, out: None = None) -> matrix[_2D, dtype[intp]]: ...
53765377
@overload
5377-
def argmin(self, axis: _ShapeLike | None, out: _ArrayT) -> _ArrayT: ...
5378+
def argmin(self, axis: _ShapeLike | None, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ...
53785379
@overload
5379-
def argmin(self, axis: _ShapeLike | None = None, *, out: _ArrayT) -> _ArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
5380+
def argmin(self, axis: _ShapeLike | None = None, *, out: _BoolOrIntArrayT) -> _BoolOrIntArrayT: ... # pyright: ignore[reportIncompatibleMethodOverride]
53805381

53815382
#the second overload handles the (rare) case that the matrix is not 2-d
53825383
@overload

numpy/_core/fromnumeric.pyi

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ _NumberOrObjectT = TypeVar("_NumberOrObjectT", bound=np.number | np.object_)
111111
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
112112
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
113113
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], covariant=True)
114+
_BoolOrIntArrayT = TypeVar("_BoolOrIntArrayT", bound=NDArray[np.integer | np.bool])
114115

115116
@type_check_only
116117
class _SupportsShape(Protocol[_ShapeT_co]):
@@ -418,18 +419,18 @@ def argmax(
418419
def argmax(
419420
a: ArrayLike,
420421
axis: SupportsIndex | None,
421-
out: _ArrayT,
422+
out: _BoolOrIntArrayT,
422423
*,
423424
keepdims: bool = ...,
424-
) -> _ArrayT: ...
425+
) -> _BoolOrIntArrayT: ...
425426
@overload
426427
def argmax(
427428
a: Ar 8000 rayLike,
428429
axis: SupportsIndex | None = ...,
429430
*,
430-
out: _ArrayT,
431+
out: _BoolOrIntArrayT,
431432
keepdims: bool = ...,
432-
) -> _ArrayT: ...
433+
) -> _BoolOrIntArrayT: ...
433434

434435
@overload
435436
def argmin(
@@ -451,18 +452,18 @@ def argmin(
451452
def argmin(
452453
a: ArrayLike,
453454
axis: SupportsIndex | None,
454-
out: _ArrayT,
455+
out: _BoolOrIntArrayT,
455456
*,
456457
keepdims: bool = ...,
457-
) -> _ArrayT: ...
458+
) -> _BoolOrIntArrayT: ...
458459
@overload
459460
def argmin(
460461
a: ArrayLike,
461462
axis: SupportsIndex | None = ...,
462463
*,
463-
out: _ArrayT,
464+
out: _BoolOrIntArrayT,
464465
keepdims: bool = ...,
465-
) -> _ArrayT: ...
466+
) -> _BoolOrIntArrayT: ...
466467

467468
@overload
468469
def searchsorted(

numpy/typing/tests/data/fail/fromnumeric.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ A = np.array(True, ndmin=2, dtype=bool)
77
A.setflags(write=False)
88
AR_U: npt.NDArray[np.str_]
99
AR_M: npt.NDArray[np.datetime64]
10+
AR_f4: npt.NDArray[np.float32]
1011

1112
a = np.bool(True)
1213

@@ -50,9 +51,11 @@ np.argsort(A, order=range(5)) # type: ignore[arg-type]
5051

5152
np.argmax(A, axis="bob") # type: ignore[call-overload]
5253
np.argmax(A, kind="bob") # type: ignore[call-overload]
54+
np.argmax(A, out=AR_f4) # type: ignore[type-var]
5355

5456
np.argmin(A, axis="bob") # type: ignore[call-overload]
5557
np.argmin(A, kind="bob") # type: ignore[call-overload]
58+
np.argmin(A, out=AR_f4) # type: ignore[type-var]
5659

5760
np.searchsorted(A[0], 0, side="bob") # type: ignore[call-overload]
5861
np.searchsorted(A[0], 0, sorter=1.0) # type: ignore[call-overload]

numpy/typing/tests/data/pass/ndarray_misc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
import numpy.typing as npt
1616

1717
class SubClass(npt.NDArray[np.float64]): ...
18-
18+
class IntSubClass(npt.NDArray[np.intp]): ...
1919

2020
i4 = np.int32(1)
2121
A: np.ndarray[Any, np.dtype[np.int32]] = np.array([[1]], dtype=np.int32)
2222
B0 = np.empty((), dtype=np.int32).view(SubClass)
2323
B1 = np.empty((1,), dtype=np.int32).view(SubClass)
2424
B2 = np.empty((1, 1), dtype=np.int32).view(SubClass)
25+
B_int0: IntSubClass = np.empty((), dtype=np.intp).view(IntSubClass)
2526
C: np.ndarray[Any, np.dtype[np.int32]] = np.array([0, 1, 2], dtype=np.int32)
2627
D = np.ones(3).view(SubClass)
2728

@@ -42,12 +43,12 @@ class SubClass(npt.NDArray[np.float64]): ...
4243
i4.argmax()
4344
A.argmax()
4445
A.argmax(axis=0)
45-
A.argmax(out=B0)
46+
A.argmax(out=B_int0)
4647

4748
i4.argmin()
4849
A.argmin()
4950
A.argmin(axis=0)
50-
A.argmin(out=B0)
51+
A.argmin(out=B_int0)
5152

5253
i4.argsort()
5354
A.argsort()

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ f4: np.float32
2525
i8: np.int64
2626
f: float
2727

28+
# integer‑dtype subclass for argmin/argmax
29+
class NDArrayIntSubclass(npt.NDArray[np.intp]): ...
30+
AR_sub_i: NDArrayIntSubclass
31+
2832
assert_type(np.take(b, 0), np.bool)
2933
assert_type(np.take(f4, 0), np.float32)
3034
assert_type(np.take(f, 0), Any)
@@ -89,13 +93,13 @@ assert_type(np.argmax(AR_b), np.intp)
8993
assert_type(np.argmax(AR_f4), np.intp)
9094
assert_type(np.argmax(AR_b, axis=0), Any)
9195
assert_type(np.argmax(AR_f4, axis=0), Any)
92-
assert_type(np.argmax(AR_f4, out=AR_subclass), NDArraySubclass)
96+
assert_type(np.argmax(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
9397

9498
assert_type(np.argmin(AR_b), np.intp)
9599
assert_type(np.argmin(AR_f4), np.intp)
96100
assert_type(np.argmin(AR_b, axis=0), Any)
97101
assert_type(np.argmin(AR_f4, axis=0), Any)
98-
assert_type(np.argmin(AR_f4, out=AR_subclass), NDArraySubclass)
102+
assert_type(np.argmin(AR_f4, out=AR_sub_i), NDArrayIntSubclass)
99103

100104
assert_type(np.searchsorted(AR_b[0], 0), np.intp)
101105
assert_type(np.searchsorted(AR_f4[0], 0), np.intp)

numpy/typing/tests/data/reveal/matrix.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ _Shape2D: TypeAlias = tuple[int, int]
77

88
mat: np.matrix[_Shape2D, np.dtype[np.int64]]
99
ar_f8: npt.NDArray[np.float64]
10+
ar_ip: npt.NDArray[np.intp]
1011

1112
assert_type(mat * 5, np.matrix[_Shape2D, Any])
1213
assert_type(5 * mat, np.matrix[_Shape2D, Any])
@@ -50,8 +51,8 @@ assert_type(mat.any(out=ar_f8), npt.NDArray[np.float64])
5051
assert_type(mat.all(out=ar_f8), npt.NDArray[np.float64])
5152
assert_type(mat.max(out=ar_f8), npt.NDArray[np.float64])
5253
assert_type(mat.min(out=ar_f8), npt.NDArray[np.float64])
53-
assert_type(mat.argmax(out=ar_f8), npt.NDArray[np.float64])
54-
assert_type(mat.argmin(out=ar_f8), npt.NDArray[np.float64])
54+
assert_type(mat.argmax(out=ar_ip), npt.NDArray[np.intp])
55+
assert_type(mat.argmin(out=ar_ip), npt.NDArray[np.intp])
5556
assert_type(mat.ptp(out=ar_f8), npt.NDArray[np.float64])
5657

5758
assert_type(mat.T, np.matrix[_Shape2D, np.dtype[np.int64]])

numpy/typing/tests/data/reveal/ndarray_misc.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ assert_type(AR_f8.any(out=B), SubClass)
5858
assert_type(f8.argmax(), np.intp)
5959
assert_type(AR_f8.argmax(), np.intp)
6060
assert_type(AR_f8.argmax(axis=0), Any)
61-
assert_type(AR_f8.argmax(out=B), SubClass)
61+
assert_type(AR_f8.argmax(out=AR_i8), npt.NDArray[np.intp])
6262

6363
assert_type(f8.argmin(), np.intp)
6464
assert_type(AR_f8.argmin(), np.intp)
6565
assert_type(AR_f8.argmin(axis=0), Any)
66-
assert_type(AR_f8.argmin(out=B), SubClass)
66+
assert_type(AR_f8.argmin(out=AR_i8), npt.NDArray[np.intp])
6767

6868
assert_type(f8.argsort(), npt.NDArray[Any])
6969
assert_type(AR_f8.argsort(), npt.NDArray[Any])

0 commit comments

Comments
 (0)
0