@@ -111,6 +111,8 @@ _NumberOrObjectT = TypeVar("_NumberOrObjectT", bound=np.number | np.object_)
111
111
_ArrayT = TypeVar ("_ArrayT" , bound = np .ndarray [Any , Any ])
112
112
_ShapeT = TypeVar ("_ShapeT" , bound = tuple [int , ...])
113
113
_ShapeT_co = TypeVar ("_ShapeT_co" , bound = tuple [int , ...], covariant = True )
114
+ _IndexArray = NDArray [np .signedinteger ] | NDArray [np .unsignedinteger ] | NDArray [np .bool_ ] # type alias for argmin / argmax
115
+ _OutT = TypeVar ("_OutT" , bound = _IndexArray ) # Type variable, must be assignable to _IndexArray
114
116
115
117
@type_check_only
116
118
class _SupportsShape (Protocol [_ShapeT_co ]):
@@ -401,77 +403,38 @@ def argsort(
401
403
@overload
402
404
def argmax (
403
405
a : ArrayLike ,
404
- axis : None = ...,
405
- out : None = ...,
406
- * ,
407
- keepdims : Literal [False ] = ...,
408
- ) -> intp : ...
409
- @overload
410
- def argmax (
411
- a : ArrayLike ,
412
- axis : SupportsIndex | None = ...,
413
- out : None = ...,
406
+ axis : int | None = ...,
414
407
* ,
415
- keepdims : bool = ...,
416
- ) -> Any : ...
417
- @overload
418
- def argmax (
419
- a : ArrayLike ,
420
- axis : SupportsIndex | None ,
421
- out : _ArrayT ,
422
- * ,
423
- keepdims : bool = ...,
424
- ) -> _ArrayT : ...
408
+ keepdims : bool = ...
409
+ ) -> np .integer : ...
410
+
425
411
@overload
426
412
def argmax (
427
413
a : ArrayLike ,
428
- axis : SupportsIndex | None = ...,
414
+ axis : int | None = ...,
429
415
* ,
430
- out : _ArrayT ,
431
- keepdims : bool = ...,
432
- ) -> _ArrayT : ...
416
+ out : _OutT ,
417
+ keepdims : bool = ...
418
+ ) -> _OutT : ...
433
419
434
420
@overload
435
421
def argmin (
436
422
a : ArrayLike ,
437
- axis : None = ...,
438
- out : None = ...,
439
- * ,
440
- keepdims : Literal [False ] = ...,
441
- ) -> intp : ...
442
- @overload
443
- def argmin (
444
- a : ArrayLike ,
445
- axis : SupportsIndex | None = ...,
446
- out : None = ...,
447
- * ,
448
- keepdims : bool = ...,
449
- ) -> Any : ...
450
- @overload
451
- def argmin (
452
- a : ArrayLike ,
453
- axis : SupportsIndex | None ,
454
- out : _ArrayT ,
423
+ axis : int | None = ...,
455
424
* ,
456
- keepdims : bool = ...,
457
- ) -> _ArrayT : ...
425
+ keepdims : bool = ...
426
+ ) -> np .integer : ...
427
+
458
428
@overload
459
429
def argmin (
460
430
a : ArrayLike ,
461
- axis : SupportsIndex | None = ...,
431
+ axis : int | None = ...,
462
432
* ,
463
- out : _ArrayT ,
464
- keepdims : bool = ...,
465
- ) -> _ArrayT : ...
433
+ out : _OutT ,
434
+ keepdims : bool = ...
435
+ ) -> _OutT : ...
466
436
467
- @overload
468
- def searchsorted (
469
- a : ArrayLike ,
470
- v : _ScalarLike_co ,
471
- side : _SortSide = ...,
472
- sorter : _ArrayLikeInt_co | None = ..., # 1D int array
473
- ) -> intp : ...
474
- @overload
437
+ overload
475
438
def searchsorted (
476
439
a : ArrayLike ,
477
440
v : ArrayLike ,
0 commit comments