8000 Merge pull request #27869 from PTUsumit/fix-interp-type · melissawm/numpy@7f93cf4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f93cf4

Browse files
authored
Merge pull request numpy#27869 from PTUsumit/fix-interp-type
TYP: Fix ``np.interp`` signature for scalar types
2 parents c412bed + bef5cf1 commit 7f93cf4

File tree

2 files changed

+91
-15
lines changed

2 files changed

+91
-15
lines changed

numpy/lib/_function_base_impl.pyi

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Sequence, Iterator, Callable, Iterable
1+
from collections.abc import Sequence, Callable, Iterable
22
from typing import (
33
Concatenate,
44
Literal as L,
@@ -15,8 +15,9 @@ from typing import (
1515
)
1616
from typing_extensions import deprecated
1717

18+
import numpy as np
1819
from numpy import (
19-
vectorize as vectorize,
20+
vectorize,
2021
generic,
2122
integer,
2223
floating,
@@ -35,19 +36,22 @@ from numpy._typing import (
3536
NDArray,
3637
ArrayLike,
3738
DTypeLike,
38-
_ShapeLike,
39-
_ScalarLike_co,
40-
_DTypeLike,
4139
_ArrayLike,
40+
_DTypeLike,
41+
_ShapeLike,
4242
_ArrayLikeBool_co,
4343
_ArrayLikeInt_co,
4444
_ArrayLikeFloat_co,
4545
_ArrayLikeComplex_co,
46+
_ArrayLikeNumber_co,
4647
_ArrayLikeTD64_co,
4748
_ArrayLikeDT64_co,
4849
_ArrayLikeObject_co,
4950
_FloatLike_co,
5051
_ComplexLike_co,
52+
_NumberLike_co,
53+
_ScalarLike_co,
54+
_NestedSequence
5155
)
5256

5357
__all__ = [
@@ -303,24 +307,87 @@ def diff(
303307
append: ArrayLike = ...,
304308
) -> NDArray[Any]: ...
305309

306-
@overload
310+
@overload # float scalar
307311
def interp(
308-
x: _ArrayLikeFloat_co,
312+
x: _FloatLike_co,
313+
xp: _ArrayLikeFloat_co,
314+
fp: _ArrayLikeFloat_co,
315+
left: _FloatLike_co | None = None,
316+
right: _FloatLike_co | None = None,
317+
period: _FloatLike_co | None = None,
318+
) -> float64: ...
319+
@overload # float array
320+
def interp(
321+
x: NDArray[floating | integer | np.bool] | _NestedSequence[_FloatLike_co],
309322
xp: _ArrayLikeFloat_co,
310323
fp: _ArrayLikeFloat_co,
311-
left: None | _FloatLike_co = ...,
312-
right: None | _FloatLike_co = ...,
313-
period: None | _FloatLike_co = ...,
324+
left: _FloatLike_co | None = None,
325+
right: _FloatLike_co | None = None,
326+
period: _FloatLike_co | None = None,
314327
) -> NDArray[float64]: ...
315-
@overload
328+
@overload # float scalar or array
316329
def interp(
317330
x: _ArrayLikeFloat_co,
318331
xp: _ArrayLikeFloat_co,
319-
fp: _ArrayLikeComplex_co,
320-
left: None | _ComplexLike_co = ...,
321-
right: None | _ComplexLike_co = ...,
322-
period: None | _FloatLike_co = ...,
332+
fp: _ArrayLikeFloat_co,
333+
left: _FloatLike_co | None = None,
334+
right: _FloatLike_co | None = None,
335+
period: _FloatLike_co | None = None,
336+
) -> NDArray[float64] | float64: ...
337+
@overload # complex scalar
338+
def interp(
339+
x: _FloatLike_co,
340+
xp: _ArrayLikeFloat_co,
341+
fp: _ArrayLike[complexfloating],
342+
left: _NumberLike_co | None = None,
343+
right: _NumberLike_co | None = None,
344+
period: _FloatLike_co | None = None,
345+
) -> complex128: ...
346+
@overload # complex or float scalar
347+
def interp(
348+
x: _FloatLike_co,
349+
xp: _ArrayLikeFloat_co,
350+
fp: Sequence[complex | complexfloating],
351+
left: _NumberLike_co | None = None,
352+
right: _NumberLike_co | None = None,
353+
period: _FloatLike_co | None = None,
354+
) -> complex128 | float64: ...
355+
@overload # complex array
356+
def interp(
357+
x: NDArray[floating | integer | np.bool] | _NestedSequence[_FloatLike_co],
358+
xp: _ArrayLikeFloat_co,
359+
fp: _ArrayLike[complexfloating],
360+
left: _NumberLike_co | None = None,
361+
right: _NumberLike_co | None = None,
362+
period: _FloatLike_co | None = None,
323363
) -> NDArray[complex128]: ...
364+
@overload # complex or float array
365+
def interp(
366+
x: NDArray[floating | integer | np.bool] | _NestedSequence[_FloatLike_co],
367+
xp: _ArrayLikeFloat_co,
368+
fp: Sequence[complex | complexfloating],
369+
left: _NumberLike_co | None = None,
370+
right: _NumberLike_co | None = None,
371+
period: _FloatLike_co | None = None,
372+
) -> NDArray[complex128 | float64]: ...
373+
@overload # complex scalar or array
374+
def interp(
375+
x: _ArrayLikeFloat_co,
376+
xp: _ArrayLikeFloat_co,
377+
fp: _ArrayLike[complexfloating],
378+
left: _NumberLike_co | None = None,
379+
right: _NumberLike_co | None = None,
380+
period: _FloatLike_co | None = None,
381+
) -> NDArray[complex128] | complex128: ...
382+
@overload # complex or float scalar or array
383+
def interp(
384+
x: _ArrayLikeFloat_co,
385+
xp: _ArrayLikeFloat_co,
386+
fp: _ArrayLikeNumber_co,
387+
left: _NumberLike_co | None = None,
388+
right: _NumberLike_co | None = None,
389+
period: _FloatLike_co | None = None,
390+
) -> NDArray[complex128 | float64] | complex128 | float64: ...
324391

325392
@overload
326393
def angle(z: _ComplexLike_co, deg: bool = ...) -> floating[Any]: ...

numpy/typing/tests/data/reveal/lib_function_base.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,15 @@ assert_type(np.diff("bob", n=0), str)
9494
assert_type(np.diff(AR_f8, axis=0), npt.NDArray[Any])
9595
assert_type(np.diff(AR_LIKE_f8, prepend=1.5), npt.NDArray[Any])
9696

97+
assert_type(np.interp(1, [1], AR_f8), np.float64)
98+
assert_type(np.interp(1, [1], [1]), np.float64)
99+
assert_type(np.interp(1, [1], AR_c16), np.complex128)
100+
assert_type(np.interp(1, [1], [1j]), np.complex128) # pyright correctly infers `complex128 | float64`
101+
assert_type(np.interp([1], [1], AR_f8), npt.NDArray[np.float64])
102+
assert_type(np.interp([1], [1], [1]), npt.NDArray[np.float64])
103+
assert_type(np.interp([1], [1], AR_c16), npt.NDArray[np.complex128])
104+
assert_type(np.interp([1], [1], [1j]), npt.NDArray[np.complex128]) # pyright correctly infers `NDArray[complex128 | float64]`
105+
97106
assert_type(np.angle(f8), np.floating[Any])
98107
assert_type(np.angle(AR_f8), npt.NDArray[np.floating[Any]])
99108
assert_type(np.angle(AR_c16, deg=True), npt.NDArray[np.floating[Any]])

0 commit comments

Comments
 (0)
0