First draft of async iterables must checkpoint between yields, and af… · python-trio/flake8-async@7b46176 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7b46176

Browse files
committed
First draft of async iterables must checkpoint between yields, and after last one
1 parent 48b3e3c commit 7b46176

10 files changed

+624
-50
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Changelog
22
*[CalVer, YY.month.patch](https://calver.org/)*
33

4+
## Future
5+
- Add TRIO109: Async iterables should have a checkpoint after each yield.
6+
47
## 22.7.5
58
- Add TRIO103: `except BaseException` or `except trio.Cancelled` with a code path that doesn't re-raise
69
- Add TRIO104: "Cancelled and BaseException must be re-raised" if user tries to return or raise a different exception.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ pip install flake8-trio
3131
- **TRIO107**: Async functions must have at least one checkpoint on every code path, unless an exception is raised.
3232
- **TRIO108**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised.
3333
Checkpoints are `await`, `async with`, and `async for`.
34+
- **TRIO109**: Async iterables should have a checkpoint after each yield.

flake8_trio.py

Lines changed: 165 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class Flake8TrioVisitor(ast.NodeVisitor):
4040
def __init__(self) -> None:
4141
super().__init__()
4242
self.problems: List[Error] = []
43+
self.suppress_errors = False
4344

4445
@classmethod
4546
def run(cls, tree: ast.AST) -> Generator[Error, None, None]:
@@ -55,7 +56,8 @@ def visit_nodes(self, nodes: Union[ast.AST, Iterable[ast.AST]]) -> None:
5556
self.visit(node)
5657

5758
def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any):
58-
self.problems.append(make_error(error, lineno, col, *args, **kwargs))
59+
if not self.suppress_errors:
60+
self.problems.append(make_error(error, lineno, col, *args, **kwargs))
5961

6062

6163
class TrioScope:
@@ -472,108 +474,235 @@ def visit_Call(self, node: ast.Call):
472474
self.generic_visit(node)
473475

474476

475-
class Visitor107_108(Flake8TrioVisitor):
477+
class Visitor107_108_109(Flake8TrioVisitor):
476478
def __init__(self) -> None:
477479
super().__init__()
478480
self.all_await = True
481+
self.checkpoint_after_yield = True
482+
self.has_yield = False
483+
self.checkpoint_continue = True
484+
self.checkpoint_break = True
485+
486+
def get_state(self) -> Tuple[bool, bool, bool, bool]:
487+
return (
488+
self.all_await,
489+
self.checkpoint_after_yield,
490+
self.has_yield,
491+
self.suppress_errors,
492+
)
493+
494+
def set_state(self, state: Tuple[bool, bool, bool, bool]):
495+
(
496+
self.all_await,
497+
self.checkpoint_after_yield,
498+
self.has_yield,
499+
self.suppress_errors,
500+
) = state
479501

480502
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
481-
outer = self.all_await
503+
outer = self.get_state()
482504

483505
# do not require checkpointing if overloading
484506
self.all_await = has_decorator(node.decorator_list, "overload")
507+
# don't need checkpoint before first yield
508+
self.checkpoint_after_yield = True
509+
485510
self.generic_visit(node)
486511

487512
if not self.all_await:
488513
self.error(TRIO107, node.lineno, node.col_offset)
514+
if self.has_yield and not self.checkpoint_after_yield:
515+
self.error(TRIO109, node.lineno, node.col_offset)
489516

490-
self.all_await = outer
517+
self.set_state(outer)
491518

492519
def visit_Return(self, node: ast.Return):
493520
self.generic_visit(node)
494521
if not self.all_await:
495522
self.error(TRIO108, node.lineno, node.col_offset)
523+
if self.has_yield and not self.checkpoint_after_yield:
524+
self.error(TRIO109, node.lineno, node.col_offset)
525+
496526
# avoid duplicate error messages
497527
self.all_await = True
498528

499-
# disregard raise's in nested functions
529+
# disregard checkpoints in nested functions
500530
def visit_FunctionDef(self, node: ast.FunctionDef):
501-
outer = self.all_await
531+
outer = self.get_state()
502532
self.generic_visit(node)
503-
self.all_await = outer
533+
self.set_state(outer)
504534

505535
# checkpoint functions
506-
def visit_Await(
507-
self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith, ast.Raise]
508-
):
536+
def visit_Await(self, node: Union[ast.Await, ast.Raise]):
537+
# the expression being awaited is not checkpointed
538+
# so only set checkpoint after the await node
509539
self.generic_visit(node)
510-
self.all_await = True
511-
512-
visit_AsyncFor = visit_Await
513-
visit_AsyncWith = visit_Await
540+
self.all_await = self.checkpoint_after_yield = True
514541

515542
# raising exception means we don't need to checkpoint so we can treat it as one
516543
visit_Raise = visit_Await
517544

545+
# checkpoint on enter and exit of with body
546+
def visit_AsyncWith(self, node: ast.AsyncWith):
547+
self.visit_nodes(node.items)
548+
self.all_await = self.checkpoint_after_yield = True
549+
self.visit_nodes(node.body)
550+
self.all_await = self.checkpoint_after_yield = True
551+
552+
def visit_Yield(self, node: ast.Yield):
553+
self.generic_visit(node)
554+
self.has_yield = True
555+
if not self.checkpoint_after_yield:
556+
self.error(TRIO109, node.lineno, node.col_offset)
557+
self.checkpoint_after_yield = False
558+
518559
# valid checkpoint if there's valid checkpoints (or raise) in at least one of:
519560
# (try or else) and all excepts
520561
# finally
562+
# 109: if yield_checkpoint == True after body+else and all excepts, set True
521563
def visit_Try(self, node: ast.Try):
522-
if self.all_await:
523-
self.generic_visit(node)
524-
return
564+
outer_await = self.all_await
565+
outer_checkpoint = self.checkpoint_after_yield
525566

526567
# check try body
527568
self.visit_nodes(node.body)
528569
body_await = self.all_await
570+
try_checkpoint = self.checkpoint_after_yield
571+
572+
# TODO: write test
573+
worst_case_try_checkpoint = outer_checkpoint and not any(
574+
isinstance(n, ast.Yield) for body in node.body for n in ast.walk(body)
575+
)
529576
self.all_await = False
530577

531578
# check that all except handlers checkpoint (await or most likely raise)
532579
all_except_await = True
580+
all_except_checkpoint = True
533581
for handler in node.handlers:
582+
# if there's any `yield`s in try body, exception might be thrown there
583+
self.checkpoint_after_yield = worst_case_try_checkpoint
584+
534585
self.visit_nodes(handler)
535586
all_except_await &= self.all_await
587+
all_except_checkpoint &= self.checkpoint_after_yield
536588
self.all_await = False
537589

538590
# check else
591+
# if else runs it's after all of try, so restore state to back then
592+
self.checkpoint_after_yield = try_checkpoint
539593
self.visit_nodes(node.orelse)
540594

541-
# (try or else) and all excepts
542-
self.all_await = (body_await or self.all_await) and all_except_await
595+
# outer or ((try or else) and all excepts)
596+
self.all_await = outer_await or (
597+
(body_await or self.all_await) and all_except_await
598+
)
599+
600+
self.checkpoint_after_yield &= all_except_checkpoint
543601

544-
# finally can check on its own
545-
self.visit_nodes(node.finalbody)
602+
if node.finalbody:
603+
# if there's a finally, it can get jumped to at the worst time
604+
self.checkpoint_after_yield &= worst_case_try_checkpoint
605+
self.visit_nodes(node.finalbody)
606+
607+
# if any body unsets, keep unset
608+
# if all bodies checkpoint, checkpoint
546609

547610
# valid checkpoint if both body and orelse have checkpoints
548611
def visit_If(self, node: Union[ast.If, ast.IfExp]):
549-
if self.all_await:
550-
self.generic_visit(node)
551-
return
552-
553-
# ignore checkpoints in condition
612+
# checkpoints in condition happen in both branches
554613
self.visit_nodes(node.test)
555-
self.all_await = False
614+
cond_await = self.all_await
615+
cond_yield = self.checkpoint_after_yield
556616

557617
# check body
558618
self.visit_nodes(node.body)
559619
body_await = self.all_await
560-
self.all_await = False
620+
body_yield = self.checkpoint_after_yield
561621

622+
# reset to cond and check orelse
623+
self.all_await = cond_await
624+
self.checkpoint_after_yield = cond_yield
562625
self.visit_nodes(node.orelse)
563626

564-
# checkpoint if both body and else
565-
self.all_await = body_await and self.all_await
627+
# checkpoint if both body and else checkpoint
628+
self.all_await &= body_await
629+
self.checkpoint_after_yield &= body_yield
566630

567631
# inline if
568632
visit_IfExp = visit_If
569633

570-
# ignore checkpoints in loops due to continue/break shenanigans
571-
def visit_While(self, node: Union[ast.While, ast.For]):
572-
outer = self.all_await
573-
self.generic_visit(node)
574-
self.all_await = outer
634+
# checkpoints in loop condition are valid, but ignore checkpoints in loop bodies
635+
# due to continue/break/zero-iter shenanigans
636+
def visit_While(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
637+
outer_suppress_errors = self.suppress_errors
638+
outer = self.checkpoint_continue, self.checkpoint_break
639+
if isinstance(node, ast.While):
640+
self.visit_nodes(node.test)
641+
else:
642+
self.visit_nodes(node.target)
643+
self.visit_nodes(node.iter)
644+
645+
self.checkpoint_continue = self.checkpoint_break = True
646+
647+
# Async for always enters and exits loop body with checkpoint
648+
# and does not care about continue
649+
if isinstance(node, ast.AsyncFor):
650+
pre_body_await = True
651+
pre_body_yield = True
652+
self.checkpoint_after_yield = True
653+
else:
654+
pre_body_await = self.all_await
655+
pre_body_yield = self.checkpoint_after_yield
656+
657+
# silently check if body unsets yield
658+
# so we later can check if body errors out on worst case of entering
659+
self.suppress_errors = True
660+
self.checkpoint_after_yield = True
661+
self.visit_nodes(node.body)
662+
self.suppress_errors = outer_suppress_errors
663+
664+
# self.checkpoint_continue is set to False if loop body ever does
665+
# continue/break with self.checkpoint_after_yield == False
666+
667+
# enter with True only if both ways of entering loop body give True
668+
self.checkpoint_after_yield &= pre_body_yield and self.checkpoint_continue
669+
670+
self.visit_nodes(node.body)
671+
672+
# enter orelse with worst case: loop body might execute fully before
673+
# entering orelse, or not at all, at a break (or at a continue)
674+
if isinstance(node, ast.AsyncFor):
675+
self.checkpoint_after_yield = True
676+
else:
677+
self.checkpoint_after_yield &= pre_body_yield
678+
pre_orelse_yield = self.checkpoint_after_yield
679+
self.visit_nodes(node.orelse)
680+
681+
# exit with worst case (orelse might or might not execute)
682+
self.checkpoint_after_yield &= pre_orelse_yield and self.checkpoint_break
683+
684+
# reset all_await to before body, since checkpoints in either might not execute
685+
self.all_await = pre_body_await
686+
687+
# TODO: write tests that require this line
688+
self.checkpoint_continue, self.checkpoint_break = outer
575689

576690
visit_For = visit_While
691+
visit_AsyncFor = visit_While
692+
693+
def visit_Continue(self, node: ast.Continue):
694+
self.checkpoint_continue &= self.checkpoint_after_yield
695+
696+
def visit_Break(self, node: ast.Break):
697+
self.checkpoint_break &= self.checkpoint_after_yield
698+
699+
# first node in boolops can checkpoint, the others might not execute
700+
def visit_BoolOp(self, node: ast.BoolOp):
701+
self.visit(node.op)
702+
self.visit_nodes(node.values[:1])
703+
outer = self.all_await
704+
self.visit_nodes(node.values[1:])
705+
self.all_await = outer
577706

578707

579708
class Plugin:
@@ -603,3 +732,4 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
603732
TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work"
604733
TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised"
605734
TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it."
735+
TRIO109 = "TRIO109: Async iterables should have a checkpoint after each yield."

0 commit comments

Comments
 (0)
0