8000 TYP: np.argmin changes · numpy/numpy@da7423a · GitHub
[go: up one dir, main page]

Skip to content

Commit da7423a

Browse files
committed
TYP: np.argmin changes
1 parent 00f2733 commit da7423a

File tree

2 files changed

+35
-56
lines changed

2 files changed

+35
-56
lines changed

numpy/_core/fromnumeric.pyi

Lines changed: 19 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ _NumberOrObjectT = TypeVar("_NumberOrObjectT", bound=np.number | np.object_)
111111
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
112112
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
113113
_ShapeT_co = TypeVar("_ShapeT_co", bound=tuple[int, ...], covariant=True)
114+
_IndexArray = NDArray[np.signedinteger] | NDArray[np.unsignedinteger] | NDArray[np.bool_] # type alias for argmin / argmax
115+
_OutT = TypeVar("_OutT", bound=_IndexArray) # Type variable, must be assignable to _IndexArray
114116

115117
@type_check_only
116118
class _SupportsShape(Protocol[_ShapeT_co]):
@@ -401,77 +403,38 @@ def argsort(
401403
@overload
402404
def argmax(
403405
a: ArrayLike,
404-
axis: None = ...,
405-
out: None = ...,
406-
*,
407-
keepdims: Literal[False] = ...,
408-
) -> intp: ...
409-
@overload
410-
def argmax(
411-
a: ArrayLike,
412-
axis: SupportsIndex | None = ...,
413-
out: None = ...,
406+
axis: int | None = ...,
414407
*,
415-
keepdims: bool = ...,
416-
) -> Any: ...
417-
@overload
418-
def argmax(
419-
a: ArrayLike,
420-
axis: SupportsIndex | None,
421-
out: _ArrayT,
422-
*,
423-
keepdims: bool = ...,
424-
) -> _ArrayT: ...
408+
keepdims: bool = ...
409+
) -> np.integer: ...
410+
425411
@overload
426412
def argmax(
427413
a: ArrayLike,
428-
axis: SupportsIndex | None = ...,
414+
axis: int | None = ...,
429415
*,
430-
out: _ArrayT,
431-
keepdims: bool = ...,
432-
) -> _ArrayT: ...
416+
out: _OutT,
417+
keepdims: bool = ...
418+
) -> _OutT: ...
433419

434420
@overload
435421
def argmin(
436422
a: ArrayLike,
437-
axis: None = ...,
438-
out: None = ...,
439-
*,
440-
keepdims: Literal[False] = ...,
441-
) -> intp: ...
442-
@overload
443-
def argmin(
444-
a: ArrayLike,
445-
axis: SupportsIndex | None = ...,
446-
out: None = ...,
447-
*,
448-
keepdims: bool = ...,
449-
) -> Any: ...
450-
@overload
451-
def argmin(
452-
a: ArrayLike,
453-
axis: SupportsIndex | None,
454-
out: _ArrayT,
423+
axis: int | None = ...,
455424
*,
456-
keepdims: bool = ...,
457-
) -> _ArrayT: ...
425+
keepdims: bool = ...
426+
) -> np.integer: ...
427+
458428
@overload
459429
def argmin(
460430
a: ArrayLike,
461-
axis: SupportsIndex | None = ...,
431+
axis: int | None = ...,
462432
*,
463-
out: _ArrayT,
464-
keepdims: bool = ...,
465-
) -> _ArrayT: ...
433+
out: _OutT,
434+
keepdims: bool = ...
435+
) -> _OutT: ...
466436

467-
@overload
468-
def searchsorted(
469-
a: ArrayLike,
470-
v: _ScalarLike_co,
471-
side: _SortSide = ...,
472-
sorter: _ArrayLikeInt_co | None = ..., # 1D int array
473-
) -> intp: ...
474-
@overload
437+
overload
475438
def searchsorted(
476439
a: ArrayLike,
477440
v: ArrayLike,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
from numpy.typing import NDArray
3+
4+
arr = np.array([1, 2, 3])
5+
6+
out_int: NDArray[np.int64] = np.empty(())
7+
_ = np.argmin(arr, out=out_int) # expected to pass
8+
_ = np.argmax(arr, out=out_int) # expected to pass
9+
10+
out_bool: NDArray[np.bool_] = np.empty(())
11+
_ = np.argmin(arr, out=out_bool) # expected to pass
12+
_ = np.argmax(arr, out=out_bool) # expected to pass
13+
14+
out_bad: NDArray[np.float64] = np.empty(())
15+
_ = np.argmin(arr, out=out_bad) # should fail static typing
16+
_ = np.argmax(arr, out=out_bad) # should fail static typing

0 commit comments

Comments
 (0)
0