8000 TYP,ENH: Improve typing with the help of `ParamSpec` by BvB93 · Pull Request #20885 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

TYP,ENH: Improve typing with the help of ParamSpec #20885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
8000 Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpy/lib/function_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def asarray_chkfinite(
order: _OrderKACF = ...,
) -> NDArray[Any]: ...

# TODO: Use PEP 612 `ParamSpec` once mypy supports `Concatenate`
# xref python/mypy#8645
@overload
def piecewise(
x: _ArrayLike[_SCT],
Expand Down
2 changes: 2 additions & 0 deletions numpy/lib/shape_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def put_along_axis(
axis: None | int,
) -> None: ...

# TODO: Use PEP 612 `ParamSpec` once mypy supports `Concatenate`
# xref python/mypy#8645
@overload
def apply_along_axis(
func1d: Callable[..., _ArrayLike[_SCT]],
Expand Down
32 changes: 17 additions & 15 deletions numpy/testing/_private/utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from typing import (
Final,
SupportsIndex,
)
from typing_extensions import ParamSpec

from numpy import generic, dtype, number, object_, bool_, _FloatValue
from numpy.typing import (
Expand All @@ -36,6 +37,7 @@ from unittest.case import (
SkipTest as SkipTest,
)

_P = ParamSpec("_P")
_T = TypeVar("_T")
_ET = TypeVar("_ET", bound=BaseException)
_FT = TypeVar("_FT", bound=Callable[..., Any])
Expand Down Expand Up @@ -254,10 +256,10 @@ def raises(*args: type[BaseException]) -> Callable[[_FT], _FT]: ...
@overload
def assert_raises( # type: ignore
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
callable: Callable[..., Any],
callable: Callable[_P, Any],
/,
*args: Any,
**kwargs: Any,
*args: _P.args,
**kwargs: _P.kwargs,
) -> None: ...
@overload
def assert_raises(
Expand All @@ -270,10 +272,10 @@ def assert_raises(
def assert_raises_regex(
expected_exception: type[BaseException] | tuple[type[BaseException], ...],
expected_regex: str | bytes | Pattern[Any],
callable: Callable[..., Any],
callable: Callable[_P, Any],
/,
*args: Any,
**kwargs: Any,
*args: _P.args,
**kwargs: _P.kwargs,
) -> None: ...
@overload
def assert_raises_regex(
Expand Down Expand Up @@ -336,20 +338,20 @@ def assert_warns(
@overload
def assert_warns(
warning_class: type[Warning],
func: Callable[..., _T],
func: Callable[_P, _T],
/,
*args: Any,
**kwargs: Any,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _T: ...

@overload
def assert_no_warnings() -> contextlib._GeneratorContextManager[None]: ...
@overload
def assert_no_warnings(
func: Callable[..., _T],
func: Callable[_P, _T],
/,
*args: Any,
**kwargs: Any,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _T: ...

@overload
Expand Down Expand Up @@ -384,10 +386,10 @@ def temppath(
def assert_no_gc_cycles() -> contextlib._GeneratorContextManager[None]: ...
@overload
def assert_no_gc_cycles(
func: Callable[..., Any],
func: Callable[_P, Any],
/,
*args: Any,
**kwargs: Any,
*args: _P.args,
**kwargs: _P.kwargs,
) -> None: ...

def break_cycles() -> None: ...
2 changes: 2 additions & 0 deletions numpy/typing/tests/data/fail/testing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,7 @@ np.testing.assert_array_max_ulp(AR_U, AR_U) # E: incompatible type

np.testing.assert_warns(warning_class=RuntimeWarning, func=func) # E: No overload variant
np.testing.assert_no_warnings(func=func) # E: No overload variant
np.testing.assert_no_warnings(func, None) # E: Too many arguments
np.testing.assert_no_warnings(func, test=None) # E: Unexpected keyword argument

np.testing.assert_no_gc_cycles(func=func) # E: No overload variant
4 changes: 4 additions & 0 deletions numpy/typing/tests/data/reveal/testing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,12 @@ reveal_type(np.testing.assert_array_max_ulp(AR_i8, AR_f8, dtype=np.float32)) #
reveal_type(np.testing.assert_warns(RuntimeWarning)) # E: _GeneratorContextManager[None]
reveal_type(np.testing.assert_warns(RuntimeWarning, func3, 5)) # E: bool

def func4(a: int, b: str) -> bool: ...

reveal_type(np.testing.assert_no_warnings()) # E: _GeneratorContextManager[None]
reveal_type(np.testing.assert_no_warnings(func3, 5)) # E: bool
reveal_type(np.testing.assert_no_warnings(func4, a=1, b="test")) # E: bool
reveal_type(np.testing.assert_no_warnings(func4, 1, "test")) # E: bool

reveal_type(np.testing.tempdir("test_dir")) # E: _GeneratorContextManager[builtins.str]
reveal_type(np.testing.tempdir(prefix=b"test")) # E: _GeneratorContextManager[builtins.bytes]
Expand Down
0