From 5bb36bbbc0056d9eee8c75907f2a969dec7d6966 Mon Sep 17 00:00:00 2001 From: jorenham Date: Tue, 28 Jan 2025 18:08:36 +0100 Subject: [PATCH] TYP: Avoid upcasting ``float64`` in the set-ops --- numpy/lib/_arraysetops_impl.pyi | 96 +++++-------------- .../typing/tests/data/reveal/arraysetops.pyi | 7 +- 2 files changed, 28 insertions(+), 75 deletions(-) diff --git a/numpy/lib/_arraysetops_impl.pyi b/numpy/lib/_arraysetops_impl.pyi index 3261cdac8cf6..20f2d576bf00 100644 --- a/numpy/lib/_arraysetops_impl.pyi +++ b/numpy/lib/_arraysetops_impl.pyi @@ -10,35 +10,7 @@ from typing import ( from typing_extensions import deprecated import numpy as np -from numpy import ( - generic, - number, - ushort, - ubyte, - uintc, - uint, - ulonglong, - short, - int8, - byte, - intc, - int_, - intp, - longlong, - half, - single, - double, - longdouble, - csingle, - cdouble, - clongdouble, - timedelta64, - datetime64, - object_, - str_, - bytes_, - void, -) +from numpy import generic, number, int8, intp, timedelta64, object_ from numpy._typing import ( ArrayLike, @@ -75,33 +47,17 @@ _NumberType = TypeVar("_NumberType", bound=number[Any]) # Only relevant if two or more arguments are parametrized, (e.g. `setdiff1d`) # which could result in, for example, `int64` and `float64`producing a # `number[_64Bit]` array -_SCTNoCast = TypeVar( - "_SCTNoCast", +_EitherSCT = TypeVar( + "_EitherSCT", np.bool, - ushort, - ubyte, - uintc, - uint, - ulonglong, - short, - byte, - intc, - int_, - longlong, - half, - single, - double, - longdouble, - csingle, - cdouble, - clongdouble, - timedelta64, - datetime64, - object_, - str_, - bytes_, - void, -) + np.int8, np.int16, np.int32, np.int64, np.intp, + np.uint8, np.uint16, np.uint32, np.uint64, np.uintp, + np.float16, np.float32, np.float64, np.longdouble, + np.complex64, np.complex128, np.clongdouble, + np.timedelta64, np.datetime64, + np.bytes_, np.str_, np.void, np.object_, + np.integer, np.floating, np.complexfloating, np.character, +) # fmt: skip class UniqueAllResult(NamedTuple, Generic[_SCT]): values: NDArray[_SCT] @@ -339,11 +295,11 @@ def unique_values(x: ArrayLike, /) -> NDArray[Any]: ... @overload def intersect1d( - ar1: _ArrayLike[_SCTNoCast], - ar2: _ArrayLike[_SCTNoCast], + ar1: _ArrayLike[_EitherSCT], + ar2: _ArrayLike[_EitherSCT], assume_unique: bool = ..., return_indices: L[False] = ..., -) -> NDArray[_SCTNoCast]: ... +) -> NDArray[_EitherSCT]: ... @overload def intersect1d( ar1: ArrayLike, @@ -353,11 +309,11 @@ def intersect1d( ) -> NDArray[Any]: ... @overload def intersect1d( - ar1: _ArrayLike[_SCTNoCast], - ar2: _ArrayLike[_SCTNoCast], + ar1: _ArrayLike[_EitherSCT], + ar2: _ArrayLike[_EitherSCT], assume_unique: bool = ..., return_indices: L[True] = ..., -) -> tuple[NDArray[_SCTNoCast], NDArray[intp], NDArray[intp]]: ... +) -> tuple[NDArray[_EitherSCT], NDArray[intp], NDArray[intp]]: ... @overload def intersect1d( ar1: ArrayLike, @@ -368,10 +324,10 @@ def intersect1d( @overload def setxor1d( - ar1: _ArrayLike[_SCTNoCast], - ar2: _ArrayLike[_SCTNoCast], + ar1: _ArrayLike[_EitherSCT], + ar2: _ArrayLike[_EitherSCT], assume_unique: bool = ..., -) -> NDArray[_SCTNoCast]: ... +) -> NDArray[_EitherSCT]: ... @overload def setxor1d( ar1: ArrayLike, @@ -400,9 +356,9 @@ def in1d( @overload def union1d( - ar1: _ArrayLike[_SCTNoCast], - ar2: _ArrayLike[_SCTNoCast], -) -> NDArray[_SCTNoCast]: ... + ar1: _ArrayLike[_EitherSCT], + ar2: _ArrayLike[_EitherSCT], +) -> NDArray[_EitherSCT]: ... @overload def union1d( ar1: ArrayLike, @@ -411,10 +367,10 @@ def union1d( @overload def setdiff1d( - ar1: _ArrayLike[_SCTNoCast], - ar2: _ArrayLike[_SCTNoCast], + ar1: _ArrayLike[_EitherSCT], + ar2: _ArrayLike[_EitherSCT], assume_unique: bool = ..., -) -> NDArray[_SCTNoCast]: ... +) -> NDArray[_EitherSCT]: ... @overload def setdiff1d( ar1: ArrayLike, diff --git a/numpy/typing/tests/data/reveal/arraysetops.pyi b/numpy/typing/tests/data/reveal/arraysetops.pyi index 33793f8deebc..eabc7677cde9 100644 --- a/numpy/typing/tests/data/reveal/arraysetops.pyi +++ b/numpy/typing/tests/data/reveal/arraysetops.pyi @@ -2,10 +2,7 @@ from typing import Any import numpy as np import numpy.typing as npt -from numpy.lib._arraysetops_impl import ( - UniqueAllResult, UniqueCountsResult, UniqueInverseResult -) -from numpy._typing import _64Bit +from numpy.lib._arraysetops_impl import UniqueAllResult, UniqueCountsResult, UniqueInverseResult from typing_extensions import assert_type @@ -28,7 +25,7 @@ assert_type(np.intersect1d(AR_M, AR_M, assume_unique=True), npt.NDArray[np.datet assert_type(np.intersect1d(AR_f8, AR_i8), npt.NDArray[Any]) assert_type( np.intersect1d(AR_f8, AR_f8, return_indices=True), - tuple[npt.NDArray[np.floating[_64Bit]], npt.NDArray[np.intp], npt.NDArray[np.intp]], + tuple[npt.NDArray[np.float64], npt.NDArray[np.intp], npt.NDArray[np.intp]], ) assert_type(np.setxor1d(AR_i8, AR_i8), npt.NDArray[np.int64])