8000 Merge pull request #5 from jakkdl/trio102 · python-trio/flake8-async@b10621e · GitHub
[go: up one dir, main page]

Skip to content

Commit b10621e

Browse files
authored
Merge pull request #5 from jakkdl/trio102
2 parents adfaa56 + 70da625 commit b10621e

File tree

7 files changed

+338
-34
lines changed

7 files changed

+338
-34
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
# Changelog
22
*[CalVer, YY.month.patch](https://calver.org/)*
33

4+
## 22.7.3
5+
- Added TRIO102 check for unsafe checkpoints inside `finally:` blocks
6+
47
## 22.7.2
58
- Avoid `TRIO100` false-alarms on cancel scopes containing `async for` or `async with`.
69

710
## 22.7.1
8-
- Initial release
11+
- Initial release with TRIO100 and TRIO101

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ pip install flake8-trio
2222
context does not contain any `await` statements. This makes it pointless, as
2323
the timeout can only be triggered by a checkpoint.
2424
- **TRIO101** `yield` inside a nursery or cancel scope is only safe when implementing a context manager - otherwise, it breaks exception handling.
25+
- **TRIO102** it's unsafe to await inside `finally:` unless you use a shielded
26+
cancel scope with a timeout"

flake8_trio.py

Lines changed: 173 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
import ast
1313
import tokenize
14-
from typing import Any, Generator, List, Optional, Set, Tuple, Type, Union
14+
from typing import Any, Collection, Generator, List, Optional, Tuple, Type, Union
1515

1616
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
17-
__version__ = "22.7.2"
17+
__version__ = "22.7.3"
1818

1919

2020
Error = Tuple[int, int, str, Type[Any]]
@@ -36,66 +36,206 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) ->
3636
return (lineno, col, error.format(*args, **kwargs), type(Plugin))
3737

3838

39-
def is_trio_call(node: ast.AST, *names: str) -> Optional[str]:
39+
class TrioScope:
40+
def __init__(self, node: ast.Call, funcname: str, packagename: str):
41+
self.node = node
42+
self.funcname = funcname
43+
self.packagename = packagename
44+
self.variable_name: Optional[str] = None
45+
self.shielded: bool = False
46+
self.has_timeout: bool = False
47+
48+
if self.funcname == "CancelScope":
49+
for kw in node.keywords:
50+
# Only accepts constant values
51+
if kw.arg == "shield" and isinstance(kw.value, ast.Constant):
52+
self.shielded = kw.value.value
53+
# sets to True even if timeout is explicitly set to inf
54+
if kw.arg == "deadline":
55+
self.has_timeout = True
56+
else:
57+
self.has_timeout = True
58+
59+
def __str__(self):
60+
# Not supporting other ways of importing trio
61+
# if self.packagename is None:
62+
# return self.funcname
63+
return f"{self.packagename}.{self.funcname}"
64+
65+
66+
def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]:
4067
if (
4168
isinstance(node, ast.Call)
4269
and isinstance(node.func, ast.Attribute)
4370
and isinstance(node.func.value, ast.Name)
4471
and node.func.value.id == "trio"
4572
and node.func.attr in names
4673
):
47-
return "trio." + node.func.attr
74+
# return "trio." + node.func.attr
75+
return TrioScope(node, node.func.attr, node.func.value.id)
4876
return None
4977

5078

79+
def has_decorator(decorator_list: List[ast.expr], names: Collection[str]):
80+
for dec in decorator_list:
81+
if (isinstance(dec, ast.Name) and dec.id in names) or (
82+
isinstance(dec, ast.Attribute) and dec.attr in names
83+
):
84+
return True
85+
return False
86+
87+
88+
class Visitor102(ast.NodeVisitor):
89+
def __init__(self) -> None:
90+
super().__init__()
91+
self.problems: List[Error] = []
92+
self._inside_finally: bool = False
93+
self._scopes: List[TrioScope] = []
94+
self._context_manager = False
95+
96+
def visit_Assign(self, node: ast.Assign) -> None:
97+
# checks for <scopename>.shield = [True/False]
98+
if self._scopes and len(node.targets) == 1:
99+
last_scope = self._scopes[-1]
100+
target = node.targets[0]
101+
if (
102+
last_scope.variable_name is not None
103+
and isinstance(target, ast.Attribute)
104+
and isinstance(target.value, ast.Name)
105+
and target.value.id == last_scope.variable_name
106+
and target.attr == "shield"
107+
and isinstance(node.value, ast.Constant)
108+
):
109+
last_scope.shielded = node.value.value
110+
self.generic_visit(node)
111+
112+
def visit_Await(self, node: ast.Await) -> None:
113+
self.check_for_trio102(node)
114+
self.generic_visit(node)
115+
116+
def visit_With(self, node: Union[ast.With, ast.AsyncWith]) -> None:
117+
trio_scope = None
118+
119+
# Check for a `with trio.<scope_creater>`
120+
for item in node.items:
121+
trio_scope = get_trio_scope(
122+
item.context_expr, "open_nursery", *cancel_scope_names
123+
)
124+
if trio_scope is not None:
125+
# check if it's saved in a variable
126+
if isinstance(item.optional_vars, ast.Name):
127+
trio_scope.variable_name = item.optional_vars.id
128+
break
129+
130+
if trio_scope is not None:
131+
self._scopes.append(trio_scope)
132+
133+
self.generic_visit(node)
134+
135+
if trio_scope is not None:
136+
self._scopes.pop()
137+
138+
def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
139+
self.check_for_trio102(node)
140+
self.visit_With(node)
141+
142+
def visit_AsyncFor(self, node: ast.AsyncFor) -> None:
143+
self.check_for_trio102(node)
144+
self.generic_visit(node)
145+
146+
def visit_FunctionDef(
147+
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
148+
) -> None:
149+
outer_cm = self._context_manager
150+
151+
# check for @<context_manager_name> and @<library>.<context_manager_name>
152+
if has_decorator(node.decorator_list, context_manager_names):
153+
self._context_manager = True
154+
155+
self.generic_visit(node)
156+
self._context_manager = outer_cm
157+
158+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
159+
self.visit_FunctionDef(node)
160+
161+
def visit_Try(self, node: ast.Try) -> None:
162+
# There's no visit_Finally, so we need to manually visit the Try fields.
163+
# It's important to do self.visit instead of self.generic_visit since
164+
# the nodes in the fields might be registered elsewhere in this class.
165+
for item in (*node.body, *node.handlers, *node.orelse):
166+
self.visit(item)
167+
168+
outer = self._inside_finally
169+
outer_scopes = self._scopes
170+
171+
self._scopes = []
172+
self._inside_finally = True
173+
174+
for item in node.finalbody:
175+
self.visit(item)
176+
177+
self._scopes = outer_scopes
178+
self._inside_finally = outer
179+
180+
def check_for_trio102(self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith]):
181+
# if we're inside a finally, and not inside a context_manager, and we're not
182+
# inside a scope that doesn't have both a timeout and shield
183+
if (
184+
self._inside_finally
185+
and not self._context_manager
186+
and not any(scope.has_timeout and scope.shielded for scope in self._scopes)
187+
):
188+
self.problems.append(make_error(TRIO102, node.lineno, node.col_offset))
189+
190+
51191
class Visitor(ast.NodeVisitor):
52192
def __init__(self) -> None:
53193
super().__init__()
54194
self.problems: List[Error] = []
55-
self.safe_yields: Set[ast.Yield] = set()
56195
self._yield_is_error = False
57196
self._context_manager = False
58197

59-
def visit_generic_with(self, node: Union[ast.With, ast.AsyncWith]):
198+
def visit_With(self, node: Union[ast.With, ast.AsyncWith]) -> None:
60199
self.check_for_trio100(node)
61200

62-
outer = self._yield_is_error
63-
if not self._context_manager and any(
64-
is_trio_call(item, "open_nursery", *cancel_scope_names)
65-
for item in (i.context_expr for i in node.items)
66-
):
67-
self._yield_is_error = True
201+
outer_yie = self._yield_is_error
202+
203+
# Check for a `with trio.<scope_creater>`
204+
if not self._context_manager:
205+
for item in (i.context_expr for i in node.items):
206+
if (
207+
get_trio_scope(item, "open_nursery", *cancel_scope_names)
208+
is not None
209+
):
210+
self._yield_is_error = True
211+
break
68212

69213
self.generic_visit(node)
70-
self._yield_is_error = outer
71214

72-
def visit_With(self, node: ast.With) -> None:
73-
self.visit_generic_with(node)
215+
# reset yield_is_error
216+
self._yield_is_error = outer_yie
74217

75218
def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
76-
self.visit_generic_with(node)
219+
self.visit_With(node)
77220

78-
def visit_generic_FunctionDef(
221+
def visit_FunctionDef(
79222
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
80-
):
223+
) -> None:
81224
outer_cm = self._context_manager
82225
outer_yie = self._yield_is_error
83226
self._yield_is_error = False
84-
if any(
85-
(isinstance(d, ast.Name) and d.id in context_manager_names)
86-
or (isinstance(d, ast.Attribute) and d.attr in context_manager_names)
87-
for d in node.decorator_list
88-
):
227+
228+
# check for @<context_manager_name> and @<library>.<context_manager_name>
229+
if has_decorator(node.decorator_list, context_manager_names):
89230
self._context_manager = True
231+
90232
self.generic_visit(node)
233+
91234
self._context_manager = outer_cm
92235
self._yield_is_error = outer_yie
93236

94-
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
95-
self.visit_generic_FunctionDef(node)
96-
97237
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
98-
self.visit_generic_FunctionDef(node)
238+
self.visit_FunctionDef(node)
99239

100240
def visit_Yield(self, node: ast.Yield) -> None:
101241
if self._yield_is_error:
@@ -106,7 +246,7 @@ def visit_Yield(self, node: ast.Yield) -> None:
106246
def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None:
107247
# Context manager with no `await` call within
108248
for item in (i.context_expr for i in node.items):
109-
call = is_trio_call(item, *cancel_scope_names)
249+
call = get_trio_scope(item, *cancel_scope_names)
110250
if call and not any(
111251
isinstance(x, checkpoint_node_types) for x in ast.walk(node)
112252
):
@@ -129,10 +269,12 @@ def from_filename(cls, filename: str) -> "Plugin":
129269
return cls(ast.parse(source))
130270

131271
def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
132-
visitor = Visitor()
133-
visitor.visit(self._tree)
134-
yield from visitor.problems
272+
for v in (Visitor, Visitor102):
273+
visitor = v()
274+
visitor.visit(self._tree)
275+
yield from visitor.problems
135276

136277

137278
TRIO100 = "TRIO100: {} context contains no checkpoints, add `await trio.sleep(0)`"
138279
TRIO101 = "TRIO101: yield inside a nursery or cancel scope is only safe when implementing a context manager - otherwise, it breaks exception handling"
280+
TRIO102 = "TRIO102: it's unsafe to await inside `finally:` unless you use a shielded cancel scope with a timeout"

tests/test_flake8_trio.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from hypothesis import HealthCheck, given, settings
1010
from hypothesmith import from_grammar, from_node
1111

12-
from flake8_trio import TRIO100, TRIO101, Error, Plugin, Visitor, make_error
12+
from flake8_trio import TRIO100, TRIO101, TRIO102, Error, Plugin, Visitor, make_error
1313

1414

1515
class Flake8TrioTestCase(unittest.TestCase):
@@ -48,6 +48,33 @@ def test_trio101(self):
4848
make_error(TRIO101, 15, 8),
4949
make_error(TRIO101, 27, 8),
5050
make_error(TRIO101, 38, 8),
51+
make_error(TRIO101, 59, 8),
52+
)
53+
54+
def test_trio102(self):
55+
self.assert_expected_errors(
56+
"trio102.py",
57+
make_error(TRIO102, 24, 8),
58+
make_error(TRIO102, 30, 12),
59+
make_error(TRIO102, 36, 12),
60+
make_error(TRIO102, 62, 12),
61+
make_error(TRIO102, 70, 12),
62+
make_error(TRIO102, 74, 12),
63+
make_error(TRIO102, 76, 12),
64+
make_error(TRIO102, 80, 12),
65+
make_error(TRIO102, 82, 12),
66+
make_error(TRIO102, 84, 12),
67+
make_error(TRIO102, 88, 12),
68+
make_error(TRIO102, 92, 8),
69+
make_error(TRIO102, 94, 8),
70+
make_error(TRIO102, 101, 12),
71+
)
72+
73+
@unittest.skipIf(sys.version_info < (3, 9), "requires 3.9+")
74+
def test_trio102_py39(self):
75+
self.assert_expected_errors(
76+
"trio102_py39.py",
77+
make_error(TRIO102, 15, 12),
5178
)
5279

5380

tests/trio101.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextlib
22
import contextlib as bla
3-
from contextlib import asynccontextmanager, contextmanager
3+
from contextlib import asynccontextmanager, contextmanager, contextmanager as blahabla
44

55
import trio
66

@@ -51,3 +51,9 @@ async def foo7():
5151
def foo8():
5252
with trio.open_nursery() as _:
5353
yield 1 # safe
54+
55+
56+
@blahabla
57+
def foo9():
58+
with trio.open_nursery() as _:
59+
yield 1 # error

0 commit comments

Comments
 (0)
0