8000 Use Typevar defaults for `TaskStatus` and `Matcher` (#3019) · python-trio/trio@26cc6ee · GitHub
[go: up one dir, main page]

Skip to content

Commit 26cc6ee

Browse files
authored
Use Typevar defaults for TaskStatus and Matcher (#3019)
* Default TaskStatus to use None if unspecified * Default Matcher to BaseException if unspecified * Update Sphinx logic for new typevar name * Add some type tests for defaulted typevar classes
1 parent b93d8a6 commit 26cc6ee

File tree

5 files changed

+91
-34
lines changed

5 files changed

+91
-34
lines changed

docs/source/conf.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,10 @@ def autodoc_process_signature(
113113
# name.
114114
assert isinstance(obj, property), obj
115115
assert isinstance(obj.fget, types.FunctionType), obj.fget
116-
assert obj.fget.__annotations__["return"] == "type[E]", obj.fget.__annotations__
117-
obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.E]"
116+
assert (
117+
obj.fget.__annotations__["return"] == "type[MatchE]"
118+
), obj.fget.__annotations__
119+
obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.MatchE]"
118120
if signature is not None:
119121
signature = signature.replace("~_contextvars.Context", "~contextvars.Context")
120122
if name == "trio.lowlevel.RunVar": # Typevar is not useful here.
@@ -123,13 +125,15 @@ def autodoc_process_signature(
123125
# Strip the type from the union, make it look like = ...
124126
signature = signature.replace(" | type[trio._core._local._NoValue]", "")
125127
signature = signature.replace("<class 'trio._core._local._NoValue'>", "...")
126-
if (
127-
name in ("trio.testing.RaisesGroup", "trio.testing.Matcher")
128-
and "+E" in signature
128+
if name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") and (
129+
"+E" in signature or "+MatchE" in signature
129130
):
130131
# This typevar being covariant isn't handled correctly in some cases, strip the +
131132
# and insert the fully-qualified name.
132133
signature = signature.replace("+E", "~trio.testing._raises_group.E")
134+
signature = signature.replace(
135+
"+MatchE", "~trio.testing._raises_group.MatchE"
136+
)
133137
if "DTLS" in name:
134138
signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context")
135139
# Don't specify PathLike[str] | PathLike[bytes], this is just for humans.

src/trio/_core/_run.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
Final,
2222
NoReturn,
2323
Protocol,
24-
TypeVar,
2524
cast,
2625
overload,
2726
)
@@ -54,12 +53,6 @@
5453
if sys.version_info < (3, 11):
5554
from exceptiongroup import BaseExceptionGroup
5655

57-
FnT = TypeVar("FnT", bound="Callable[..., Any]")
58-
StatusT = TypeVar("StatusT")
59-
StatusT_co = TypeVar("StatusT_co", covariant=True)
60-
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
61-
RetT = TypeVar("RetT")
62-
6356

6457
if TYPE_CHECKING:
6558
import contextvars
@@ -77,9 +70,19 @@
7770
# for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in
7871
# start_guest_run. Same with types.FrameType in iter_await_frames
7972
import outcome
80-
from typing_extensions import Self, TypeVarTuple, Unpack
73+
from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack
8174

8275
PosArgT = TypeVarTuple("PosArgT")
76+
StatusT = TypeVar("StatusT", default=None)
77+
StatusT_contra = TypeVar("StatusT_contra", contravariant=True, default=None)
78+
else:
79+
from typing import TypeVar
80+
81+
StatusT = TypeVar("StatusT")
82+
StatusT_contra = TypeVar("StatusT_contra", contravariant=True)
83+
84+
FnT = TypeVar("FnT", bound="Callable[..., Any]")
85+
RetT = TypeVar("RetT")
8386

8487

8588
DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000

src/trio/_tests/type_tests/raisesgroup.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def check_inheritance_and_assignments() -> None:
3737
assert a
3838

3939

40+
def check_matcher_typevar_default(e: Matcher) -> object:
41+
assert e.exception_type is not None
42+
exc: type[BaseException] = e.exception_type
43+
# this would previously pass, as the type would be `Any`
44+
e.exception_type().blah() # type: ignore
45+
return exc # Silence Pyright unused var warning
46+
47+
4048
def check_basic_contextmanager() -> None:
4149
# One level of Group is correctly translated - except it's a BaseExceptionGroup
4250
# instead of an ExceptionGroup.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Check that started() can only be called for TaskStatus[None]."""
2+
3+
from trio import TaskStatus
4+
from typing_extensions import assert_type
5+
6+
7+
async def check_status(
8+
none_status_explicit: TaskStatus[None],
9+
none_status_implicit: TaskStatus,
10+
int_status: TaskStatus[int],
11+
) -> None:
12+
assert_type(none_status_explicit, TaskStatus[None])
13+
assert_type(none_status_implicit, TaskStatus[None]) # Default typevar
14+
assert_type(int_status, TaskStatus[int])
15+
16+
# Omitting the parameter is only allowed for None.
17+
none_status_explicit.started()
18+
none_status_implicit.started()
19+
int_status.started() # type: ignore
20+
21+
# Explicit None is allowed.
22+
none_status_explicit.started(None)
23+
none_status_implicit.started(None)
24+
int_status.started(None) # type: ignore
25+
26+
none_status_explicit.started(42) # type: ignore
27+
none_status_implicit.started(42) # type: ignore
28+
int_status.started(42)
29+
int_status.started(True)

src/trio/testing/_raises_group.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
Literal,
1111
Pattern,
1212
Sequence,
13-
TypeVar,
1413
cast,
1514
overload,
1615
)
@@ -26,43 +25,57 @@
2625
import types
2726

2827
from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback
29-
from typing_extensions import TypeGuard
28+
from typing_extensions import TypeGuard, TypeVar
3029

31-
if sys.version_info < (3, 11):
32-
from exceptiongroup import BaseExceptionGroup
30+
MatchE = TypeVar(
31+
"MatchE", bound=BaseException, default=BaseException, covariant=True
32+
)
33+
else:
34+
from typing import TypeVar
3335

36+
MatchE = TypeVar("MatchE", bound=BaseException, covariant=True)
37+
# RaisesGroup doesn't work with a default.
3438
E = TypeVar("E", bound=BaseException, covariant=True)
39+
# These two typevars are special cased in sphinx config to workaround lookup bugs.
40+
41+
if sys.version_info < (3, 11):
42+
from exceptiongroup import BaseExceptionGroup
3543

3644

3745
@final
38-
class _ExceptionInfo(Generic[E]):
46+
class _ExceptionInfo(Generic[MatchE]):
3947
"""Minimal re-implementation of pytest.ExceptionInfo, only used if pytest is not available. Supports a subset of its features necessary for functionality of :class:`trio.testing.RaisesGroup` and :class:`trio.testing.Matcher`."""
4048

41-
_excinfo: tuple[type[E], E, types.TracebackType] | None
49+
_excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None
4250

43-
def __init__(self, excinfo: tuple[type[E], E, types.TracebackType] | None):
51+
def __init__(
52+
self, excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None
53+
):
4454
self._excinfo = excinfo
4555

46-
def fill_unfilled(self, exc_info: tuple[type[E], E, types.TracebackType]) -> None:
56+
def fill_unfilled(
57+
self, exc_info: tuple[type[MatchE], MatchE, types.TracebackType]
58+
) -> None:
4759
"""Fill an unfilled ExceptionInfo created with ``for_later()``."""
4860
assert self._excinfo is None, "ExceptionInfo was already filled"
4961
self._excinfo = exc_info
5062

5163
@classmethod
52-
def for_later(cls) -> _ExceptionInfo[E]:
64+
def for_later(cls) -> _ExceptionInfo[MatchE]:
5365
"""Return an unfilled ExceptionInfo."""
5466
return cls(None)
5567

68+
# Note, special cased in sphinx config, since "type" conflicts.
5669
@property
57-
def type(self) -> type[E]:
70+
def type(self) -> type[MatchE]:
5871
"""The exception class."""
5972
assert (
6073
self._excinfo is not None
6174
), ".type can only be used after the context manager exits"
6275
return self._excinfo[0]
6376

6477
@property
65-
def value(self) -> E:
78+
def value(self) -> MatchE:
6679
"""The exception value."""
6780
assert (
6881
self._excinfo is not None
@@ -95,7 +108,7 @@ def getrepr(
95108
showlocals: bool = False,
96109
style: str = "long",
97110
abspath: bool = False,
98-
tbfilter: bool | Callable[[_ExceptionInfo[BaseException]], Traceback] = True,
111+
tbfilter: bool | Callable[[_ExceptionInfo], Traceback] = True,
99112
funcargs: bool = False,
100113
truncate_locals: bool = True,
101114
chain: bool = True,
@@ -135,7 +148,7 @@ def _stringify_exception(exc: BaseException) -> str:
135148

136149

137150
@final
138-
class Matcher(Generic[E]):
151+
class Matcher(Generic[MatchE]):
139152
"""Helper class to be used together with RaisesGroups when you want to specify requirements on sub-exceptions. Only specifying the type is redundant, and it's also unnecessary when the type is a nested `RaisesGroup` since it supports the same arguments.
140153
The type is checked with `isinstance`, and does not need to be an exact match. If that is wanted you can use the ``check`` parameter.
141154
:meth:`trio.testing.Matcher.matches` can also be used standalone to check individual exceptions.
@@ -154,10 +167,10 @@ class Matcher(Generic[E]):
154167
# At least one of the three parameters must be passed.
155168
@overload
156169
def __init__(
157-
self: Matcher[E],
158-
exception_type: type[E],
170+
self: Matcher[MatchE],
171+
exception_type: type[MatchE],
159172
match: str | Pattern[str] = ...,
160-
check: Callable[[E], bool] = ...,
173+
check: Callable[[MatchE], bool] = ...,
161174
): ...
162175

163176
@overload
@@ -174,9 +187,9 @@ def __init__(self, *, check: Callable[[BaseException], bool]): ...
174187

175188
def __init__(
176189
self,
177-
exception_type: type[E] | None = None,
190+
exception_type: type[MatchE] | None = None,
178191
match: str | Pattern[str] | None = None,
179-
check: Callable[[E], bool] | None = None,
192+
check: Callable[[MatchE], bool] | None = None,
180193
):
181194
if exception_type is None and match is None and check is None:
182195
raise ValueError("You must specify at least one parameter to match on.")
@@ -192,7 +205,7 @@ def __init__(
192205
self.match = match
193206
self.check = check
194207

195-
def matches(self, exception: BaseException) -> TypeGuard[E]:
208+
def matches(self, exception: BaseException) -> TypeGuard[MatchE]:
196209
"""Check if an exception matches the requirements of this Matcher.
197210
198211
Examples::
@@ -220,7 +233,7 @@ def matches(self, exception: BaseException) -> TypeGuard[E]:
220233
return False
221234
# If exception_type is None check() accepts BaseException.
222235
# If non-none, we have done an isinstance check above.
223-
if self.check is not None and not self.check(cast(E, exception)):
236+
if self.check is not None and not self.check(cast(MatchE, exception)):
224237
return False
225238
return True
226239

@@ -254,8 +267,8 @@ def __str__(self) -> str:
254267
# We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups
255268
if TYPE_CHECKING:
256269
SuperClass = BaseExceptionGroup
257-
# Inheriting at runtime leads to a series of TypeErrors, so we do not want to do that.
258270
else:
271+
# At runtime, use a redundant Generic base class which effectively gets ignored.
259272
SuperClass = Generic
260273

261274

0 commit comments

Comments
 (0)
0