8000 Merge pull request #22 from jakkdl/11_nursery_outermost_async_context… · python-trio/flake8-async@3f8e4f3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3f8e4f3

Browse files
authored
Merge pull request #22 from jakkdl/11_nursery_outermost_async_context_manager
2 parents d8acd4d + d655eb7 commit 3f8e4f3

File tree

5 files changed

+393
-40
lines changed

5 files changed

+393
-40
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# Changelog
22
*[CalVer, YY.month.patch](https://calver.org/)*
33

4-
## Future
5-
- add TRIO112, nursery body with only a call to `nursery.start[_soon]` and not passing itself as a parameter can be replaced with a regular function call.
4+
## 22.8.5
5+
- Add TRIO111: Variable, from context manager opened inside nursery, passed to `start[_soon]` might be invalidly accesed while in use, due to context manager closing before the nursery. This is usually a bug, and nurseries should generally be the inner-most context manager.
6+
- Add TRIO112: this single-task nursery could be replaced by awaiting the function call directly.
67

78
## 22.8.4
89
- Fix TRIO108 raising errors on yields in some sync code.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@ pip install flake8-trio
3333
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
3434
- **TRIO109**: Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead
3535
- **TRIO110**: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.
36+
- **TRIO111**: Variable, from context manager opened inside nursery, passed to `start[_soon]` might be invalidly accesed while in use, due to context manager closing before the nursery. This is usually a bug, and nurseries should generally be the inner-most context manager.
3637
- **TRIO112**: nursery body with only a call to `nursery.start[_soon]` and not passing itself as a parameter can be replaced with a regular function call.

flake8_trio.py

Lines changed: 124 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626

2727
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
28-
__version__ = "22.8.4"
28+
__version__ = "22.8.5"
2929

3030

3131
Error_codes = {
@@ -55,7 +55,12 @@
5555
"`trio.[fail/move_on]_[after/at]` instead"
5656
),
5757
"TRIO110": "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.",
58-
"TRIO112": "Redundant nursery {}, consider replacing with a regular function call",
58+
"TRIO111": (
59+
"variable {2} is usable within the context manager on line {0}, but that "
60+
"will close before nursery opened on line {1} - this is usually a bug. "
61+
"Nurseries should generally be the inner-most context manager."
62+
),
63+
"TRIO112": "Redundant nursery {}, consider replacing with directly awaiting the function call",
5964
}
6065

6166

@@ -162,10 +167,18 @@ def error(self, error: str, node: HasLineCol, *args: object):
162167
if not self.suppress_errors:
163168
self._problems.append(Error(error, node.lineno, node.col_offset, *args))
164169

165-
def get_state(self, *attrs: str) -> Dict[str, Any]:
170+
def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
166171
if not attrs:
167172
attrs = tuple(self.__dict__.keys())
168-
return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
173+
res: Dict[str, Any] = {}
174+
for attr in attrs:
175+
if attr == "_problems":
176+
continue
177+
value = getattr(self, attr)
178+
if copy and hasattr(value, "copy"):
179+
value = value.copy()
180+
res[attr] = value
181+
return res
169182

170183
def set_state(self, attrs: Dict[str, Any], copy: bool = False):
171184
for attr, value in attrs.items():
@@ -187,37 +200,68 @@ def has_decorator(decorator_list: List[ast.expr], *names: str):
187200
return False
188201

189202

190-
# handles 100, 101, 106, 109, 110
203+
# handles 100, 101, 106, 109, 110, 111, 112
191204
class VisitorMiscChecks(Flake8TrioVisitor):
205+
class NurseryCall(NamedTuple):
206+
stack_index: int
207+
name: str
208+
209+
class TrioContextManager(NamedTuple):
210+
lineno: int
211+
name: str
212+
is_nursery: bool
213+
192214
def __init__(self):
193215
super().__init__()
194216

195-
# variables only used for 101
217+
# 101
196218
self._yield_is_error = False
197219
self._safe_decorator = False
198220

199-
# ---- 100, 101 ----
221+
# 111
222+
self._context_managers: List[VisitorMiscChecks.TrioContextManager] = []
223+
self._nursery_call: Optional[VisitorMiscChecks.NurseryCall] = None
224+
225+
self.defaults = self.get_state(copy=True)
226+
227+
# ---- 100, 101, 111, 112 ----
200228
def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
201-
# 100
202229
self.check_for_trio100(node)
203230
self.check_for_trio112(node)
204231

205-
# 101 for rest of function
206-
outer = self.get_state("_yield_is_error")
232+
outer = self.get_state("_yield_is_error", "_context_managers", copy=True)
207233

208-
# Check for a `with trio.<scope_creater>`
209-
if not self._safe_decorator:
210-
for item in (i.context_expr for i in node.items):
211-
if (
212-
get_matching_call(item, "open_nursery", *cancel_scope_names)
213-
is not None
214-
):
215-
self._yield_is_error = True
216-
break
234+
for item in node.items:
235+
# 101
236+
# if there's no safe decorator,
237+
# and it's not yet been determined that yield is error
238+
# and this withitem opens a cancelscope:
239+
# then yielding is unsafe
240+
if (
241+
not self._safe_decorator
242+
and not self._yield_is_error
243+
and get_matching_call(
244+
item.context_expr, "open_nursery", *cancel_scope_names
245+
)
246+
is not None
247+
):
248+
self._yield_is_error = True
217249

218-
self.generic_visit(node)
250+
# 111
251+
# if a withitem is saved in a variable,
252+
# push its line, variable, and whether it's a trio nursery
253+
# to the _context_managers stack,
254+
if isinstance(item.optional_vars, ast.Name):
255+
self._context_managers.append(
256+
self.TrioContextManager(
257+
item.context_expr.lineno,
258+
item.optional_vars.id,
259+
get_matching_call(item.context_expr, "open_nursery")
260+
is not None,
261+
)
262+
)
219263

220-
# reset yield_is_error
264+
self.generic_visit(node)
221265
self.set_state(outer)
222266

223267
visit_AsyncWith = visit_With
@@ -236,7 +280,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
236280
# ---- 101 ----
237281
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
238282
outer = self.get_state()
239-
self._yield_is_error = False
283+
self.set_state(self.defaults, copy=True)
240284

241285
# check for @<context_manager_name> and @<library>.<context_manager_name>
242286
if has_decorator(node.decorator_list, *context_manager_names):
@@ -251,6 +295,12 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
251295
self.check_for_trio109(node)
252296
self.visit_FunctionDef(node)
253297

298+
def visit_Lambda(self, node: ast.Lambda):
299+
outer = self.get_state()
300+
self.set_state(self.defaults, copy=True)
301+
self.generic_visit(node)
302+
self.set_state(outer)
303+
254304
# ---- 101 ----
255305
def visit_Yield(self, node: ast.Yield):
256306
if self._yield_is_error:
@@ -260,8 +310,11 @@ def visit_Yield(self, node: ast.Yield):
260310

261311
# ---- 109 ----
262312
def check_for_trio109(self, node: ast.AsyncFunctionDef):
313+
# pending configuration or a more sophisticated check, ignore
314+
# all functions with a decorator
263315
if node.decorator_list:
264316
return
317+
265318
args = node.args
266319
for arg in (*args.posonlyargs, *args.args, *args.kwonlyargs):
267320
if arg.arg == "timeout":
@@ -277,6 +330,7 @@ def visit_Import(self, node: ast.Import):
277330
for name in node.names:
278331
if name.name == "trio" and name.asname is not None:
279332
self.error("TRIO106", node)
333+
self.generic_visit(node)
280334

281335
# ---- 110 ----
282336
def visit_While(self, node: ast.While):
@@ -292,6 +346,53 @@ def check_for_trio110(self, node: ast.While):
292346
):
293347
self.error("TRIO110", node)
294348

349+
# ---- 111 ----
350+
# if it's a <X>.start[_soon] call
351+
# and <X> is a nursery listed in self._context_managers:
352+
# Save <X>'s index in self._context_managers to guard against cm's higher in the
353+
# stack being passed as parameters to it. (and save <X> for the error message)
354+
def visit_Call(self, node: ast.Call):
355+
outer = self.get_state("_nursery_call")
356+
357+
if (
358+
isinstance(node.func, ast.Attribute)
359+
and isinstance(node.func.value, ast.Name)
360+
and node.func.attr in ("start", "start_soon")
361+
):
362+
self._nursery_call = None
363+
for i, cm in enumerate(self._context_managers):
364+
if node.func.value.id == cm.name:
365+
# don't break upon finding a nursery in case there's multiple cm's
366+
# on the stack with the same name
367+
if cm.is_nursery:
368+
self._nursery_call = self.NurseryCall(i, node.func.attr)
369+
else:
370+
self._nursery_call = None
371+
372+
self.generic_visit(node)
373+
self.set_state(outer)
374+
375+
# If we're inside a <X>.start[_soon] call (where <X> is a nursery),
376+
# and we're accessing a variable cm that's on the self._context_managers stack,
377+
# with a higher index than <X>:
378+
# Raise error since the scope of cm may close before the function passed to the
379+
# nursery finishes.
380+
def visit_Name(self, node: ast.Name):
381+
self.generic_visit(node)
382+
if self._nursery_call is None:
383+
return
384+
385+
for i, cm in enumerate(self._context_managers):
386+
if cm.name == node.id and i > self._nursery_call.stack_index:
387+
self.error(
388+
"TRIO111",
389+
node,
390+
cm.lineno,
391+
self._context_managers[self._nursery_call.stack_index].lineno,
392+
node.id,
393+
self._nursery_call.name,
394+
)
395+
295396
# if with has a withitem `trio.open_nursery() as <X>`,
296397
# and the body is only a single expression <X>.start[_soon](),
297398
# and does not pass <X> as a parameter to the expression
@@ -323,6 +424,7 @@ def check_for_trio112(self, node: Union[ast.With, ast.AsyncWith]):
323424
self.error("TRIO112", item.context_expr, var_name)
324425

325426

427+
# used in 102, 103 and 104
326428
def critical_except(node: ast.ExceptHandler) -> Optional[Statement]:
327429
def has_exception(node: Optional[ast.expr]) -> str:
328430
if isinstance(node, ast.Name) and node.id == "BaseException":

tests/test_flake8_trio.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,27 @@ def test_eval(test: str, path: str):
7272
try:
7373
# Append a bunch of empty strings so string formatting gives garbage
7474
# instead of throwing an exception
75-
args = eval(
76-
f"[{reg_match}]",
77-
{
78-
"lineno": lineno,
79-
"line": lineno,
80-
"Statement": Statement,
81-
"Stmt": Statement,
82-
},
83-
)
75+
try:
76+
args = eval(
77+
f"[{reg_match}]",
78+
{
79+
"lineno": lineno,
80+
"line": lineno,
81+
"Statement": Statement,
82+
"Stmt": Statement,
83+
},
84+
)
85+
except NameError:
86+
print(f"failed to eval on line {lineno}", file=sys.stderr)
87+
raise
8488

8589
except Exception as e:
8690
print(f"lineno: {lineno}, line: {line}", file=sys.stderr)
8791
raise e
88-
col, *args = args
92+
if args:
93+
col, *args = args
94+
else:
95+
col = 0
8996
assert isinstance(
9097
col, int
9198
), f'invalid column "{col}" @L{lineno}, in "{line}"'
@@ -163,13 +170,15 @@ def assert_expected_errors(plugin: Plugin, include: Iterable[str], *expected: Er
163170

164171
def print_first_diff(errors: Sequence[Error], expected: Sequence[Error]):
165172
first_error_line: List[Error] = []
166-
for e in errors:
167-
if e.line == errors[0].line:
168-
first_error_line.append(e)
169173
first_expected_line: List[Error] = []
170-
for e in expected:
171-
if e.line == expected[0].line:
172-
first_expected_line.append(e)
174+
for err, exp in zip(errors, expected):
175+
if err == exp:
176+
continue
177+
if not first_error_line or err.line == first_error_line[0]:
178+
first_error_line.append(err)
179+
if not first_expected_line or exp.line == first_expected_line[0]:
180+
first_expected_line.append(exp)
181+
173182
if first_expected_line != first_error_line:
174183
print(
175184
"First lines with different errors",

0 commit comments

Comments
 (0)
0