8000 added comment, overloaded functions are safe · python-trio/flake8-async@065404d · GitHub
[go: up one dir, main page]

Skip to conten 8000 t

Commit 065404d

Browse files
committed
added comment, overloaded functions are safe
1 parent daa7c88 commit 065404d

File tree

3 files changed

+43
-33
lines changed

3 files changed

+43
-33
lines changed

flake8_trio.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,7 @@
1111

1212
import ast
1313
import tokenize
14-
from typing import (
15-
Any,
16-
Collection,
17-
Generator,
18-
Iterable,
19-
List,
20-
Optional,
21-
Tuple,
22-
Type,
23-
Union,
24-
)
14+
from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union
2515

2616
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
2717
__version__ = "22.7.4"
@@ -102,7 +92,7 @@ def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]:
10292
return None
10393

10494

105-
def has_decorator(decorator_list: List[ast.expr], names: Collection[str]):
95+
def has_decorator(decorator_list: List[ast.expr], *names: str):
10696
for dec in decorator_list:
10797
if (isinstance(dec, ast.Name) and dec.id in names) or (
10898
isinstance(dec, ast.Attribute) and dec.attr in names
@@ -148,7 +138,7 @@ def visit_FunctionDef(
148138
self._yield_is_error = False
149139

150140
# check for @<context_manager_name> and @<library>.<context_manager_name>
151-
if has_decorator(node.decorator_list, context_manager_names):
141+
if has_decorator(node.decorator_list, *context_manager_names):
152142
self._context_manager = True
153143

154144
self.generic_visit(node)
@@ -251,7 +241,7 @@ def visit_FunctionDef(
251241
8000 outer_cm = self._context_manager
252242

253243
# check for @<context_manager_name> and @<library>.<context_manager_name>
254-
if has_decorator(node.decorator_list, context_manager_names):
244+
if has_decorator(node.decorator_list, *context_manager_names):
255245
self._context_manager = True
256246

257247
self.generic_visit(node)
@@ -343,7 +333,8 @@ def __init__(self) -> None:
343333
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
344334
outer = self.all_await
345335

346-
self.all_await = False
336+
# do not require checkpointing if overloading
337+
self.all_await = has_decorator(node.decorator_list, "overload")
347338
self.generic_visit(node)
348339

349340
if not self.all_await:

tests/test_flake8_trio.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -139,29 +139,29 @@ def test_trio106(self):
139139
def test_trio300_301(self):
140140
self.assert_expected_errors(
141141
"trio300_301.py",
142-
make_error(TRIO300, 10, 0),
142+
make_error(TRIO300, 13, 0),
143143
# if
144-
make_error(TRIO300, 15, 0),
145-
make_error(TRIO300, 33, 0),
144+
make_error(TRIO300, 18, 0),
145+
make_error(TRIO300, 36, 0),
146146
# ifexp
147-
make_error(TRIO300, 43, 0),
147+
make_error(TRIO300, 46, 0),
148148
# loops
149-
make_error(TRIO300, 48, 0),
150-
make_error(TRIO300, 53, 0),
151-
make_error(TRIO300, 66, 0),
152-
make_error(TRIO300, 71, 0),
149+
make_error(TRIO300, 51, 0),
150+
make_error(TRIO300, 56, 0),
151+
make_error(TRIO300, 69, 0),
152+
make_error(TRIO300, 74, 0),
153153
# try
154-
make_error(TRIO300, 79, 0),
154+
make_error(TRIO300, 83, 0),
155155
# early return
156-
make_error(TRIO301, 136, 4),
157-
make_error(TRIO301, 141, 8),
156+
make_error(TRIO301, 140, 4),
157+
make_error(TRIO301, 145, 8),
158158
# nested function definition
159-
make_error(TRIO300, 145, 0),
160-
make_error(TRIO300, 155, 4),
161-
make_error(TRIO300, 159, 0),
162-
make_error(TRIO300, 166, 8),
163-
make_error(TRIO300, 164, 0),
164-
make_error(TRIO300, 170, 0),
159+
make_error(TRIO300, 149, 0),
160+
make_error(TRIO300, 159, 4),
161+
make_error(TRIO300, 163, 0),
162+
make_error(TRIO300, 170, 8),
163+
make_error(TRIO300, 168, 0),
164+
make_error(TRIO300, 174, 0),
165165
)
166166

167167

tests/trio300_301.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import typing
2+
from typing import Union, overload
3+
14
import trio
25

36
_ = ""
@@ -76,7 +79,8 @@ async def foo_for_2(): # error: due to not wanting to handle continue/break sem
7679

7780

7881
# try
79-
async def foo_try_1(): # error
82+
# safe only if (try or else) and all except bodies either await or raise
83+
async def foo_try_1(): # error: if foo() raises a ValueError it's not checkpointed
8084
try:
8185
await foo()
8286
except ValueError:
@@ -179,3 +183,18 @@ def foo_normal_func_1():
179183

180184
def foo_normal_func_2():
181185
...
186+
187+
188+
# overload decorator
189+
@overload
190+
async def foo_overload_1(_: bytes):
191+
...
192+
193+
194+
@typing.overload
195+
async def foo_overload_1(_: str):
196+
...
197+
198+
199+
async def foo_overload_1(_: Union[bytes, str]):
200+
await foo()

0 commit comments

Comments
 (0)
0