implement autofix for 91x · python-trio/flake8-async@0bf4add · GitHub
[go: up one dir, main page]

Skip to content

Commit 0bf4add

Browse files
committed
implement autofix for 91x
1 parent dc83998 commit 0bf4add

File tree

8 files changed

+2131
-32
lines changed

8 files changed

+2131
-32
lines changed

flake8_trio/visitors/visitor91x.py

Lines changed: 147 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from __future__ import annotations
99

1010
from dataclasses import dataclass, field
11-
from typing import Any
11+
from typing import TYPE_CHECKING, Any
1212

1313
import libcst as cst
1414
import libcst.matchers as m
@@ -24,9 +24,19 @@
2424
iter_guaranteed_once_cst,
2525
)
2626

27+
if TYPE_CHECKING:
28+
from collections.abc import Sequence
29+
30+
2731
ARTIFICIAL_STATEMENT = Statement("artificial", -1)
2832

2933

34+
def checkpoint_statement() -> cst.SimpleStatementLine:
    """Build the statement line that the autofix inserts as a checkpoint.

    A new node is parsed on every call, so each insertion gets its own
    statement object.

    The expression is awaited: ``trio.lowlevel.checkpoint`` is an async
    function, so a bare call would only create a coroutine object and would
    not checkpoint.

    Returns:
        A ``cst.SimpleStatementLine`` holding ``await trio.lowlevel.checkpoint()``.
    """
    checkpoint_call = cst.parse_expression("await trio.lowlevel.checkpoint()")
    return cst.SimpleStatementLine([cst.Expr(checkpoint_call)])
38+
39+
3040
def func_empty_body(node: cst.FunctionDef) -> bool:
3141
# Does the function body consist solely of `pass`, `...`, and (doc)string literals?
3242
empty_statement = m.Pass() | m.Expr(m.Ellipsis() | m.SimpleString())
@@ -46,8 +56,10 @@ class LoopState:
4656

4757
uncheckpointed_before_continue: set[Statement] = field(default_factory=set)
4858
uncheckpointed_before_break: set[Statement] = field(default_factory=set)
49-
artificial_errors: set[cst.Return | cst.FunctionDef | cst.Yield] = field(
50-
default_factory=set
59+
60+
artificial_errors: set[cst.Return | cst.Yield] = field(default_factory=set)
61+
nodes_needing_checkpoints: list[cst.Yield | cst.Return] = field(
62+
default_factory=list
5163
)
5264

5365
def copy(self):
@@ -58,6 +70,7 @@ def copy(self):
5870
uncheckpointed_before_continue=self.uncheckpointed_before_continue.copy(),
5971
uncheckpointed_before_break=self.uncheckpointed_before_break.copy(),
6072
artificial_errors=self.artificial_errors.copy(),
73+
nodes_needing_checkpoints=self.nodes_needing_checkpoints.copy(),
6174
)
6275

6376

@@ -99,6 +112,10 @@ def __init__(self, *args: Any, **kwargs: Any):
99112
self.uncheckpointed_statements: set[Statement] = set()
100113
self.comp_unknown = False
101114

115+
# this one is not save-stated, but I fail to come up with any scenario
116+
# where that matters
117+
self.add_statement: cst.SimpleStatementLine | None = None
118+
102119
self.loop_state = LoopState()
103120
self.try_state = TryState()
104121

@@ -142,14 +159,6 @@ def uncheckpointed_before_break(self) -> set[Statement]:
142159
def uncheckpointed_before_break(self, value: set[Statement]):
143160
self.loop_state.uncheckpointed_before_break = value
144161

145-
@property
146-
def artificial_errors(self) -> set[cst.Return | cst.FunctionDef | cst.Yield]:
147-
return self.loop_state.artificial_errors
148-
149-
@artificial_errors.setter
150-
def artificial_errors(self, value: set[cst.Return | cst.FunctionDef | cst.Yield]):
151-
self.loop_state.artificial_errors = value # pragma: no cover
152-
153162
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
154163
# don't lint functions whose bodies solely consist of pass or ellipsis
155164
if func_has_decorator(node, "overload", "fixture") or func_empty_body(node):
@@ -186,38 +195,93 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
186195
def leave_FunctionDef(
    self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
    """Check the implicit function exit and optionally append a checkpoint.

    When leaving an async function, any still-uncheckpointed statements are
    reported via ``check_function_exit``. If an error was reported, autofix
    is enabled and the body is an indented block, a checkpoint statement is
    appended as the last statement of the body.

    Args:
        original_node: the node as visited; used for error reporting.
        updated_node: the (possibly transformed) node to return or amend.

    Returns:
        ``updated_node``, or a copy of it with a trailing checkpoint.
    """
    # updated_node does not have a Position, so original_node is what gets
    # passed on for error reporting
    exit_uncheckpointed = self.async_function and self.check_function_exit(
        original_node
    )
    self.restore_state(original_node)

    body = updated_node.body
    if not (
        exit_uncheckpointed
        and self.options.autofix
        and isinstance(body, cst.IndentedBlock)
    ):
        return updated_node

    patched_body = body.with_changes(body=[*body.body, checkpoint_statement()])
    return updated_node.with_changes(body=patched_body)
194212

195213
# error if function exits or returns with uncheckpointed statements
196-
def check_function_exit(
    self,
    original_node: cst.Return | cst.FunctionDef | cst.Yield,
) -> bool:
    """Report every currently uncheckpointed statement against ``original_node``.

    Args:
        original_node: the return, yield or function definition at which
            the uncheckpointed statements are being checked.

    Returns:
        True if ``error_91x`` reported at least one error, False otherwise.
    """
    pending = self.uncheckpointed_statements

    # Record this node as potentially needing a checkpoint only if that
    # solely depends on whether the artificial statement turns out "real"
    if pending == {ARTIFICIAL_STATEMENT}:
        assert isinstance(original_node, (cst.Return, cst.Yield))
        self.loop_state.nodes_needing_checkpoints.append(original_node)

    # build a list rather than feeding a generator to any(), so error_91x is
    # invoked for every statement and not short-circuited
    outcomes = [self.error_91x(original_node, stmt) for stmt in pending]
    return any(outcomes)
199229

200230
def leave_Return(
    self, original_node: cst.Return, updated_node: cst.Return
) -> cst.Return:
    """Check a ``return`` in an async function for missing checkpoints.

    If an error is reported, a checkpoint statement is queued in
    ``self.add_statement`` for insertion. Outside async functions the node
    is passed through untouched.
    """
    if self.async_function:
        if self.check_function_exit(original_node):
            self.add_statement = checkpoint_statement()

        # clear the set so the same statements are not reported again
        self.uncheckpointed_statements = set()

        # hand back the original node to avoid problems with identity equality
        assert original_node.deep_equals(updated_node)
        return original_node

    return updated_node
243+
244+
def leave_SimpleStatementSuite(
    self,
    original_node: cst.SimpleStatementSuite,
    updated_node: cst.SimpleStatementSuite,
) -> cst.SimpleStatementSuite:
    """Drop any queued checkpoint when leaving a simple statement suite.

    Inserting into a suite could probably be supported, but for now the
    pending statement is discarded rather than risk placing the checkpoint
    in the wrong place.
    """
    self.add_statement = None
    return updated_node
210253

254+
def leave_SimpleStatementLine(
    self,
    original_node: cst.SimpleStatementLine,
    updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine | cst.FlattenSentinel[cst.SimpleStatementLine]:
    """Insert the queued checkpoint statement before this statement line.

    ``self.add_statement`` is always cleared on exit. The checkpoint is only
    emitted when one was queued and the line holds a single statement.

    Returns:
        ``updated_node`` unchanged, or a ``cst.FlattenSentinel`` containing
        the checkpoint statement followed by ``updated_node``.
    """
    # take the pending statement and clear the slot in one step
    queued, self.add_statement = self.add_statement, None

    # multiple statements on a single line is not handled
    # TODO: generate an error here if transforming+visiting is done in a single pass
    # and emit-error-on-transform is enabled
    if queued is None or len(updated_node.body) > 1:
        return updated_node

    return cst.FlattenSentinel([queued, updated_node])
271+
211272
def error_91x(
212-
self, node: cst.Return | cst.FunctionDef | cst.Yield, statement: Statement
213-
):
273+
self,
274+
node: cst.Return | cst.FunctionDef | cst.Yield,
275+
statement: Statement,
276+
) -> bool:
214277
# artificial statement is injected in visit_While_body to make sure errors
215278
# are raised on multiple loops, if e.g. the end of a loop is uncheckpointed.
216279
# Here we add it to artificial errors, so loop logic can later turn it into
217280
# a real error if needed.
218281
if statement == ARTIFICIAL_STATEMENT:
282+
assert isinstance(node, (cst.Return, cst.Yield))
219283
self.loop_state.artificial_errors.add(node)
220-
return
284+
return False
221285
if isinstance(node, cst.FunctionDef):
222286
msg = "exit"
223287
else:
@@ -229,6 +293,7 @@ def error_91x(
229293
statement,
230294
error_code="TRIO911" if self.has_yield else "TRIO910",
231295
)
296+
return True
232297

233298
def leave_Await(
234299
self, original_node: cst.Await, updated_node: cst.Await
@@ -260,13 +325,16 @@ def leave_Yield(
260325
if not self.async_function:
261326
return updated_node
262327
self.has_yield = True
263-
for statement in self.uncheckpointed_statements:
264-
self.error_91x(original_node, statement)
328+
329+
if self.check_function_exit(original_node):
330+
self.add_statement = checkpoint_statement()
265331

266332
# mark as requiring checkpoint after
267333
pos = self.get_metadata(PositionProvider, original_node).start
268334
self.uncheckpointed_statements = {Statement("yield", pos.line, pos.column)}
269-
return updated_node
335+
# return original to avoid problems with identity equality
336+
assert original_node.deep_equals(updated_node)
337+
return original_node
270338

271339
# valid checkpoint if there's valid checkpoints (or raise) in:
272340
# (try or else) and all excepts, or in finally
@@ -283,6 +351,7 @@ def visit_Try(self, node: cst.Try):
283351
self.try_state.body_uncheckpointed_statements = (
284352
self.uncheckpointed_statements.copy()
285353
)
354+
# yields inside `try` can always be uncheckpointed
286355
for inner_node in m.findall(node.body, m.Yield()):
287356
pos = self.get_metadata(PositionProvider, inner_node).start
288357
self.try_state.body_uncheckpointed_statements.add(
@@ -394,6 +463,9 @@ def visit_While(self, node: cst.While | cst.For):
394463
self.loop_state = LoopState()
395464
self.infinite_loop = self.body_guaranteed_once = False
396465

466+
# big bug not having this, TODO: add test case
467+
visit_For = visit_While
468+
397469
# Check for yields w/o checkpoint in between due to entering loop body the first time,
398470
# after completing all of loop body, and after any continues.
399471
# yield in else have same requirement
@@ -409,7 +481,7 @@ def visit_While_test(self, node: cst.While):
409481
def visit_For_iter(self, node: cst.For):
410482
self.body_guaranteed_once = iter_guaranteed_once_cst(node.iter)
411483

412-
def visit_While_body(self, node: cst.While | cst.For):
484+
def visit_While_body(self, node: cst.For | cst.While):
413485
if not self.async_function:
414486
return
415487

@@ -432,7 +504,7 @@ def visit_While_body(self, node: cst.While | cst.For):
432504

433505
visit_For_body = visit_While_body
434506

435-
def leave_While_body(self, node: cst.While | cst.For):
507+
def leave_While_body(self, node: cst.For | cst.While):
436508
if not self.async_function:
437509
return
438510
# if there's errors due to the artificial statement
@@ -442,11 +514,16 @@ def leave_While_body(self, node: cst.While | cst.For):
442514
self.outer[node]["uncheckpointed_statements"]
443515
| self.uncheckpointed_statements
444516
)
445-
for err_node in self.artificial_errors:
517+
any_error = False
518+
519+
for err_node in self.loop_state.artificial_errors:
446520
for stmt in (
447521
new_uncheckpointed_statements | self.uncheckpointed_before_continue
448522
):
523+
any_error = True
449524
self.error_91x(err_node, stmt)
525+
if not any_error:
526+
self.loop_state.nodes_needing_checkpoints = []
450527

451528
# replace artificial in break with prebody uncheckpointed statements
452529
for stmts in (
@@ -479,7 +556,7 @@ def leave_While_body(self, node: cst.While | cst.For):
479556

480557
leave_For_body = leave_While_body
481558

482-
def leave_While_orelse(self, node: cst.While | cst.For):
559+
def leave_While_orelse(self, node: cst.For | cst.While):
483560
if not self.async_function:
484561
return
485562
# if this is an infinite loop, with no break in it, don't raise
@@ -494,10 +571,30 @@ def leave_While_orelse(self, node: cst.While | cst.For):
494571

495572
# reset break & continue in case of nested loops
496573
self.outer[node]["uncheckpointed_statements"] = self.uncheckpointed_statements
497-
self.restore_state(node)
498574

499575
leave_For_orelse = leave_While_orelse
500576

577+
def leave_While(
    self, original_node: cst.For | cst.While, updated_node: cst.For | cst.While
) -> (
    cst.While
    | cst.For
    | cst.FlattenSentinel[cst.For | cst.While]
    | cst.RemovalSentinel
):
    """Insert deferred checkpoints into the loop, then restore outer state.

    If the loop state lists nodes needing checkpoints, the updated loop is
    run through ``InsertCheckpointsInLoopBody`` with those nodes. Otherwise
    the updated loop is returned as is.
    """
    # must be read before restore_state is called for this loop
    deferred = self.loop_state.nodes_needing_checkpoints
    if deferred:
        result = updated_node.visit(InsertCheckpointsInLoopBody(deferred))
    else:
        result = updated_node

    self.restore_state(original_node)
    return result
595+
596+
leave_For = leave_While
597+
501598
# save state in case of continue/break at a point not guaranteed to checkpoint
502599
def visit_Continue(self, node: cst.Continue):
503600
if not self.async_function:
@@ -590,3 +687,27 @@ def visit_CompFor(self, node: cst.CompFor):
590687
# ignore their content
591688
def visit_GeneratorExp(self, node: cst.GeneratorExp):
592689
return False
690+
691+
692+
# necessary as we don't know whether to insert checkpoints on the first pass of a loop
693+
# so we transform the loop body afterwards
694+
class InsertCheckpointsInLoopBody(cst.CSTTransformer):
    """Second-pass transformer that adds checkpoints before selected nodes.

    Constructed with the yield/return nodes that were found to need a
    checkpoint. When one of those nodes is left, a checkpoint statement is
    queued, and it is emitted in front of the enclosing simple statement line.
    """

    def __init__(self, nodes_needing_checkpoint: Sequence[cst.Yield | cst.Return]):
        """Store the nodes to be preceded by a checkpoint.

        Args:
            nodes_needing_checkpoint: original yield/return nodes to match.
        """
        super().__init__()
        self.nodes_needing_checkpoint = nodes_needing_checkpoint
        self.add_statement: cst.SimpleStatementLine | None = None

    # reuse the insertion logic; it only relies on `self.add_statement`
    leave_SimpleStatementLine = Visitor91X.leave_SimpleStatementLine  # type: ignore

    def leave_SimpleStatementSuite(
        self,
        original_node: cst.SimpleStatementSuite,
        updated_node: cst.SimpleStatementSuite,
    ) -> cst.SimpleStatementSuite:
        """Discard a queued checkpoint when leaving a simple statement suite.

        A yield/return in a suite is not inside a simple statement line, so
        the queued statement would otherwise be emitted in front of a later,
        unrelated statement line. This matches what Visitor91X does for suites.
        """
        self.add_statement = None
        return updated_node

    def leave_Yield(
        self,
        original_node: cst.Yield,
        updated_node: cst.Yield,
    ) -> cst.Yield:
        """Queue a checkpoint if this node is one of the recorded nodes."""
        # we need to check *original* node here, since updated node is a copy
        # which loses identity equality
        if original_node in self.nodes_needing_checkpoint:
            self.add_statement = checkpoint_statement()
        return updated_node

    # returns are handled exactly like yields
    leave_Return = leave_Yield  # type: ignore

0 commit comments

Comments
 (0)
0