diff --git a/CHANGELOG.md b/CHANGELOG.md index 007454e8..0e58b85e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ # Changelog *[CalVer, YY.month.patch](https://calver.org/)* +## 22.7.6 +- Extend TRIO102 to also check inside `except BaseException` and `except trio.Cancelled` +- Extend TRIO104 to also check for `yield` +- Update error messages on TRIO102 and TRIO103 + ## 22.7.5 - Add TRIO103: `except BaseException` or `except trio.Cancelled` with a code path that doesn't re-raise - Add TRIO104: "Cancelled and BaseException must be re-raised" if user tries to return or raise a different exception. diff --git a/README.md b/README.md index f7a90191..0c9e1b2b 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ pip install flake8-trio context does not contain any `await` statements. This makes it pointless, as the timeout can only be triggered by a checkpoint. - **TRIO101**: `yield` inside a nursery or cancel scope is only safe when implementing a context manager - otherwise, it breaks exception handling. -- **TRIO102**: it's unsafe to await inside `finally:` unless you use a shielded +- **TRIO102**: it's unsafe to await inside `finally:` or `except BaseException/trio.Cancelled` unless you use a shielded cancel scope with a timeout. - **TRIO103**: `except BaseException` and `except trio.Cancelled` with a code path that doesn't re-raise. Note that any `raise` statements in loops are ignored since it's tricky to parse loop flow with `break`, `continue` and/or the zero-iteration case. - **TRIO104**: `Cancelled` and `BaseException` must be re-raised - when a user tries to `return` or `raise` a different exception. diff --git a/flake8_trio.py b/flake8_trio.py index af3885aa..49deb493 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -14,7 +14,7 @@ from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" -__version__ = "22.7.5" +__version__ = "22.7.6" Error = Tuple[int, int, str, Type[Any]] @@ -37,7 +37,7 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) -> class Flake8TrioVisitor(ast.NodeVisitor): - def __init__(self) -> None: + def __init__(self): super().__init__() self.problems: List[Error] = [] @@ -47,12 +47,19 @@ def run(cls, tree: ast.AST) -> Generator[Error, None, None]: visitor.visit(tree) yield from visitor.problems - def visit_nodes(self, nodes: Union[ast.AST, Iterable[ast.AST]]) -> None: - if isinstance(nodes, ast.AST): - self.visit(nodes) + def visit_nodes( + self, *nodes: Union[ast.AST, Iterable[ast.AST]], generic: bool = False + ): + if generic: + visit = self.generic_visit else: - for node in nodes: - self.visit(node) + visit = self.visit + for arg in nodes: + if isinstance(arg, ast.AST): + visit(arg) + else: + for node in arg: + visit(node) def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any): self.problems.append(make_error(error, lineno, col, *args, **kwargs)) @@ -109,18 +116,18 @@ def has_decorator(decorator_list: List[ast.expr], *names: str): # handles 100, 101 and 106 class VisitorMiscChecks(Flake8TrioVisitor): - def __init__(self) -> None: + def __init__(self): super().__init__() self._yield_is_error = False - self._context_manager = False + self._safe_decorator = False - def visit_With(self, node: Union[ast.With, ast.AsyncWith]) -> None: + def visit_With(self, node: Union[ast.With, ast.AsyncWith]): self.check_for_trio100(node) outer_yie = self._yield_is_error # Check for a `with trio.` - if not self._context_manager: + if not self._safe_decorator: for item in (i.context_expr for i in node.items): if ( get_trio_scope(item, "open_nursery", *cancel_scope_names) @@ -134,35 +141,31 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]) -> None: # reset yield_is_error self._yield_is_error = outer_yie - def visit_AsyncWith(self, node: ast.AsyncWith) -> None: + def visit_AsyncWith(self, node: ast.AsyncWith): self.visit_With(node) - def visit_FunctionDef( - self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef] - ) -> None: - outer_cm = self._context_manager - outer_yie = self._yield_is_error + def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): + outer = self._safe_decorator, self._yield_is_error self._yield_is_error = False # check for @ and @. if has_decorator(node.decorator_list, *context_manager_names): - self._context_manager = True + self._safe_decorator = True self.generic_visit(node) - self._context_manager = outer_cm - self._yield_is_error = outer_yie + self._safe_decorator, self._yield_is_error = outer - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): self.visit_FunctionDef(node) - def visit_Yield(self, node: ast.Yield) -> None: + def visit_Yield(self, node: ast.Yield): if self._yield_is_error: self.problems.append(make_error(TRIO101, node.lineno, node.col_offset)) self.generic_visit(node) - def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None: + def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]): # Context manager with no `await` call within for item in (i.context_expr for i in node.items): call = get_trio_scope(item, *cancel_scope_names) @@ -185,112 +188,153 @@ def visit_Import(self, node: ast.Import): self.problems.append(make_error(TRIO106, node.lineno, node.col_offset)) +def critical_except(node: ast.ExceptHandler) -> Optional[Tuple[int, int, str]]: + def has_exception(node: Optional[ast.expr]) -> str: + if isinstance(node, ast.Name) and node.id == "BaseException": + return "BaseException" + if ( + isinstance(node, ast.Attribute) + and isinstance(node.value, ast.Name) + and node.value.id == "trio" + and node.attr == "Cancelled" + ): + return "trio.Cancelled" + return "" + + # bare except + if node.type is None: + return node.lineno, node.col_offset, "bare except" + # several exceptions + elif isinstance(node.type, ast.Tuple): + for element in node.type.elts: + name = has_exception(element) + if name: + return element.lineno, element.col_offset, name + # single exception, either a Name or an Attribute + else: + name = has_exception(node.type) + if name: + return node.type.lineno, node.type.col_offset, name + return None + + class Visitor102(Flake8TrioVisitor): - def __init__(self) -> None: + def __init__(self): super().__init__() - self._inside_finally: bool = False - self._scopes: List[TrioScope] = [] - self._context_manager = False + self._critical_scope: Optional[Tuple[int, int, str]] = None + self._trio_context_managers: List[TrioScope] = [] + self._safe_decorator = False - def visit_Assign(self, node: ast.Assign) -> None: - # checks for .shield = [True/False] - if self._scopes and len(node.targets) == 1: - last_scope = self._scopes[-1] - target = node.targets[0] - if ( - last_scope.variable_name is not None - and isinstance(target, ast.Attribute) - and isinstance(target.value, ast.Name) - and target.value.id == last_scope.variable_name - and target.attr == "shield" - and isinstance(node.value, ast.Constant) - ): - last_scope.shielded = node.value.value - self.generic_visit(node) + # if we're inside a finally, and not inside a context_manager, and we're not + # inside a scope that doesn't have both a timeout and shield + def visit_Await( + self, + node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith], + visit_children: bool = True, + ): + if ( + self._critical_scope is not None + and not self._safe_decorator + and not any( + cm.has_timeout and cm.shielded for cm in self._trio_context_managers + ) + ): + self.problems.append( + make_error(TRIO102, node.lineno, node.col_offset, *self._critical_scope) + ) + if visit_children: + self.generic_visit(node) - def visit_Await(self, node: ast.Await) -> None: - self.check_for_trio102(node) - self.generic_visit(node) + visit_AsyncFor = visit_Await - def visit_With(self, node: Union[ast.With, ast.AsyncWith]) -> None: - trio_scope = None + def visit_With(self, node: Union[ast.With, ast.AsyncWith]): + has_context_manager = False # Check for a `with trio.` for item in node.items: trio_scope = get_trio_scope( item.context_expr, "open_nursery", *cancel_scope_names ) - if trio_scope is not None: - # check if it's saved in a variable - if isinstance(item.optional_vars, ast.Name): - trio_scope.variable_name = item.optional_vars.id - break + if trio_scope is None: + continue - if trio_scope is not None: - self._scopes.append(trio_scope) + self._trio_context_managers.append(trio_scope) + has_context_manager = True + # check if it's saved in a variable + if isinstance(item.optional_vars, ast.Name): + trio_scope.variable_name = item.optional_vars.id + break self.generic_visit(node) - if trio_scope is not None: - self._scopes.pop() + if has_context_manager: + self._trio_context_managers.pop() - def visit_AsyncWith(self, node: ast.AsyncWith) -> None: - self.check_for_trio102(node) + def visit_AsyncWith(self, node: ast.AsyncWith): + self.visit_Await(node, visit_children=False) self.visit_With(node) - def visit_AsyncFor(self, node: ast.AsyncFor) -> None: - self.check_for_trio102(node) - self.generic_visit(node) - - def visit_FunctionDef( - self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef] - ) -> None: - outer_cm = self._context_manager + def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]): + outer_cm = self._safe_decorator # check for @ and @. if has_decorator(node.decorator_list, *context_manager_names): - self._context_manager = True + self._safe_decorator = True self.generic_visit(node) - self._context_manager = outer_cm + self._safe_decorator = outer_cm - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: - self.visit_FunctionDef(node) + visit_AsyncFunctionDef = visit_FunctionDef - def visit_Try(self, node: ast.Try) -> None: - # There's no visit_Finally, so we need to manually visit the Try fields. - # It's important to do self.visit instead of self.generic_visit since - # the nodes in the fields might be registered elsewhere in this class. - for item in (*node.body, *node.handlers, *node.orelse): - self.visit(item) + def critical_visit( + self, + node: Union[ast.ExceptHandler, Iterable[ast.AST]], + block: Tuple[int, int, str], + generic: bool = False, + ): + outer = self._critical_scope, self._trio_context_managers - outer = self._inside_finally - outer_scopes = self._scopes + self._trio_context_managers = [] + self._critical_scope = block - self._scopes = [] - self._inside_finally = True + self.visit_nodes(node, generic=generic) + self._critical_scope, self._trio_context_managers = outer - for item in node.finalbody: - self.visit(item) + def visit_Try(self, node: ast.Try): + # There's no visit_Finally, so we need to manually visit the Try fields. + self.visit_nodes(node.body, node.handlers, node.orelse) + self.critical_visit( + node.finalbody, (node.lineno, node.col_offset, "try/finally") + ) - self._scopes = outer_scopes - self._inside_finally = outer + def visit_ExceptHandler(self, node: ast.ExceptHandler): + res = critical_except(node) + if res is None: + self.generic_visit(node) + else: + self.critical_visit(node, res, generic=True) - def check_for_trio102(self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith]): - # if we're inside a finally, and not inside a context_manager, and we're not - # inside a scope that doesn't have both a timeout and shield - if ( - self._inside_finally - and not self._context_manager - and not any(scope.has_timeout and scope.shielded for scope in self._scopes) - ): - self.problems.append(make_error(TRIO102, node.lineno, node.col_offset)) + def visit_Assign(self, node: ast.Assign): + # checks for .shield = [True/False] + if self._trio_context_managers and len(node.targets) == 1: + last_scope = self._trio_context_managers[-1] + target = node.targets[0] + if ( + last_scope.variable_name is not None + and isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == last_scope.variable_name + and target.attr == "shield" + and isinstance(node.value, ast.Constant) + ): + last_scope.shielded = node.value.value + self.generic_visit(node) # Never have an except Cancelled or except BaseException block with a code path that # doesn't re-raise the error class Visitor103_104(Flake8TrioVisitor): - def __init__(self) -> None: + def __init__(self): super().__init__() self.except_name: Optional[str] = "" self.unraised: bool = False @@ -300,45 +344,25 @@ def __init__(self) -> None: # set self.unraised, and if it's still set after visiting child nodes # then there might be a code path that doesn't re-raise. def visit_ExceptHandler(self, node: ast.ExceptHandler): - def has_exception(node: Optional[ast.expr]): - return (isinstance(node, ast.Name) and node.id == "BaseException") or ( - isinstance(node, ast.Attribute) - and isinstance(node.value, ast.Name) - and node.value.id == "trio" - and node.attr == "Cancelled" - ) outer = (self.unraised, self.except_name, self.loop_depth) - marker = None + marker = critical_except(node) - # we need to not unset self.unraised if this is non-critical to still + # we need to *not* unset self.unraised if this is non-critical, to still # warn about `return`s - # bare except - if node.type is None: - self.unraised = True - marker = (node.lineno, node.col_offset) - # several exceptions - elif isinstance(node.type, ast.Tuple): - for element in node.type.elts: - if has_exception(element): - self.unraised = True - marker = element.lineno, element.col_offset - break - # single exception, either a Name or an Attribute - elif has_exception(node.type): - self.unraised = True - marker = node.type.lineno, node.type.col_offset - if marker is not None: - # save name `as ` + # save name from `as ` self.except_name = node.name + self.loop_depth = 0 + self.unraised = True # visit child nodes. Will unset self.unraised if all code paths `raise` self.generic_visit(node) if self.unraised and marker is not None: + # print(marker) self.problems.append(make_error(TRIO103, *marker)) (self.unraised, self.except_name, self.loop_depth) = outer @@ -358,12 +382,14 @@ def visit_Raise(self, node: ast.Raise): self.generic_visit(node) - def visit_Return(self, node: ast.Return): + def visit_Return(self, node: Union[ast.Return, ast.Yield]): if self.unraised: # Error: must re-raise self.problems.append(make_error(TRIO104, node.lineno, node.col_offset)) self.generic_visit(node) + visit_Yield = visit_Return + # Treat Try's as fully covering only if `finally` always raises. def visit_Try(self, node: ast.Try): if not self.unraised: @@ -446,7 +472,7 @@ def visit_Break(self, node: Union[ast.Break, ast.Continue]): class Visitor105(Flake8TrioVisitor): - def __init__(self) -> None: + def __init__(self): super().__init__() self.node_stack: List[ast.AST] = [] @@ -473,7 +499,7 @@ def visit_Call(self, node: ast.Call): class Visitor107_108(Flake8TrioVisitor): - def __init__(self) -> None: + def __init__(self): super().__init__() self.all_await = True @@ -580,7 +606,7 @@ class Plugin: name = __name__ version = __version__ - def __init__(self, tree: ast.AST) -> None: + def __init__(self, tree: ast.AST): self._tree = tree @classmethod @@ -596,8 +622,8 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: TRIO100 = "TRIO100: {} context contains no checkpoints, add `await trio.sleep(0)`" TRIO101 = "TRIO101: yield inside a nursery or cancel scope is only safe when implementing a context manager - otherwise, it breaks exception handling" -TRIO102 = "TRIO102: it's unsafe to await inside `finally:` unless you use a shielded cancel scope with a timeout" -TRIO103 = "TRIO103: except Cancelled or except BaseException block with a code path that doesn't re-raise the error" +TRIO102 = "TRIO102: await inside {2} on line {0} must have shielded cancel scope with a timeout" +TRIO103 = "TRIO103: {} block with a code path that doesn't re-raise the error" TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised" TRIO105 = "TRIO105: Trio async function {} must be immediately awaited" TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work" diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 05c564a6..c26339c7 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -40,9 +40,19 @@ def trim_messages(messages: Iterable[Error]): errors = tuple(plugin.run()) # start with a check with trimmed errors that will make for smaller diff messages - trim_errors = trim_messages(plugin.run()) + trim_errors = trim_messages(errors) trim_expected = trim_messages(expected) - self.assertTupleEqual(trim_errors, trim_expected) + + # + unexpected = sorted(set(trim_errors) - set(trim_expected)) + missing = sorted(set(trim_expected) - set(trim_errors)) + self.assertEqual((unexpected, missing), ([], []), msg="(unexpected, missing)") + + unexpected = sorted(set(errors) - set(expected)) + missing = sorted(set(expected) - set(errors)) + if unexpected and missing: + self.assertEqual(unexpected[0], missing[0]) + self.assertEqual((unexpected, missing), ([], []), msg="(unexpected, missing)") # full check self.assertTupleEqual(errors, expected) @@ -82,59 +92,69 @@ def test_trio101(self): def test_trio102(self): self.assert_expected_errors( "trio102.py", - make_error(TRIO102, 24, 8), - make_error(TRIO102, 30, 12), - make_error(TRIO102, 36, 12), - make_error(TRIO102, 62, 12), - make_error(TRIO102, 70, 12), - make_error(TRIO102, 74, 12), - make_error(TRIO102, 76, 12), - make_error(TRIO102, 80, 12), - make_error(TRIO102, 82, 12), - make_error(TRIO102, 84, 12), - make_error(TRIO102, 88, 12), - make_error(TRIO102, 92, 8), - make_error(TRIO102, 94, 8), - make_error(TRIO102, 101, 12), - make_error(TRIO102, 124, 12), + make_error(TRIO102, 24, 8, 21, 4, "try/finally"), + make_error(TRIO102, 30, 12, 26, 4, "try/finally"), + make_error(TRIO102, 36, 12, 32, 4, "try/finally"), + make_error(TRIO102, 62, 12, 55, 4, "try/finally"), + make_error(TRIO102, 70, 12, 66, 4, "try/finally"), + make_error(TRIO102, 74, 12, 66, 4, "try/finally"), + make_error(TRIO102, 76, 12, 66, 4, "try/finally"), + make_error(TRIO102, 80, 12, 66, 4, "try/finally"), + make_error(TRIO102, 82, 12, 66, 4, "try/finally"), + make_error(TRIO102, 84, 12, 66, 4, "try/finally"), + make_error(TRIO102, 88, 12, 66, 4, "try/finally"), + make_error(TRIO102, 92, 8, 66, 4, "try/finally"), + make_error(TRIO102, 94, 8, 66, 4, "try/finally"), + make_error(TRIO102, 101, 12, 98, 8, "try/finally"), + make_error(TRIO102, 124, 12, 114, 4, "try/finally"), + make_error(TRIO102, 135, 8, 134, 11, "trio.Cancelled"), + make_error(TRIO102, 138, 8, 137, 11, "BaseException"), + make_error(TRIO102, 141, 8, 140, 4, "bare except"), ) def test_trio103_104(self): self.assert_expected_errors( "trio103_104.py", - make_error(TRIO103, 7, 33), - make_error(TRIO103, 15, 7), + make_error(TRIO103, 7, 33, "trio.Cancelled"), + make_error(TRIO103, 15, 7, "trio.Cancelled"), # raise different exception make_error(TRIO104, 20, 4), make_error(TRIO104, 22, 4), make_error(TRIO104, 25, 4), # if - make_error(TRIO103, 28, 7), - make_error(TRIO103, 35, 7), + make_error(TRIO103, 28, 7, "BaseException"), + make_error(TRIO103, 35, 7, "BaseException"), # loops - make_error(TRIO103, 47, 7), - make_error(TRIO103, 52, 7), + make_error(TRIO103, 47, 7, "trio.Cancelled"), + make_error(TRIO103, 52, 7, "trio.Cancelled"), # nested exceptions make_error(TRIO104, 67, 8), # weird edge-case - make_error(TRIO103, 61, 7), + make_error(TRIO103, 61, 7, "BaseException"), make_error(TRIO104, 92, 8), # make_error(TRIO104, 94, 8), # weird edge-case # bare except - make_error(TRIO103, 97, 0), + make_error(TRIO103, 97, 0, "bare except"), # multi-line - make_error(TRIO103, 111, 4), + make_error(TRIO103, 111, 4, "BaseException"), # re-raise parent make_error(TRIO104, 124, 8), # return make_error(TRIO104, 134, 8), - make_error(TRIO103, 133, 11), + make_error(TRIO103, 133, 11, "BaseException"), make_error(TRIO104, 139, 12), make_error(TRIO104, 141, 12), make_error(TRIO104, 143, 12), make_error(TRIO104, 145, 12), - make_error(TRIO103, 137, 11), + make_error(TRIO103, 137, 11, "BaseException"), + # continue/break make_error(TRIO104, 154, 12), make_error(TRIO104, 162, 12), + # yield + make_error(TRIO104, 184, 8), + make_error(TRIO104, 190, 12), + make_error(TRIO104, 192, 12), + make_error(TRIO104, 194, 12), + make_error(TRIO104, 196, 12), ) def test_trio105(self): @@ -192,15 +212,15 @@ def test_trio107_108(self): # try make_error(TRIO107, 83, 0), # early return - make_error(TRIO108, 140, 4), - make_error(TRIO108, 145, 8), + make_error(TRIO108, 141, 4), + make_error(TRIO108, 146, 8), # nested function definition - make_error(TRIO107, 149, 0), - make_error(TRIO107, 159, 4), - make_error(TRIO107, 163, 0), - make_error(TRIO107, 170, 8), - make_error(TRIO107, 168, 0), - make_error(TRIO107, 174, 0), + make_error(TRIO107, 150, 0), + make_error(TRIO107, 160, 4), + make_error(TRIO107, 164, 0), + make_error(TRIO107, 171, 8), + make_error(TRIO107, 169, 0), + make_error(TRIO107, 175, 0), ) diff --git a/tests/trio102.py b/tests/trio102.py index 35b5a579..db2a9956 100644 --- a/tests/trio102.py +++ b/tests/trio102.py @@ -122,3 +122,39 @@ async def foo3(): with trio.fail_after(5), trio.move_on_after(30) as s: s.shield = True await foo() # error: safe in theory, but we don't bother parsing + + +# New: except cancelled/baseexception are also critical +async def foo4(): + await foo() # avoid TRIO107 + try: + ... + except ValueError: + await foo() # safe + except trio.Cancelled: + await foo() # error + raise # avoid TRIO103 + except BaseException: + await foo() # error + raise # avoid TRIO103 + except: + await foo() # error + raise # avoid TRIO103 + + +async def foo5(): + await foo() # avoid TRIO107 + try: + ... + except trio.Cancelled: + with trio.CancelScope(deadline=30, shield=True): + await foo() # safe + raise # avoid TRIO103 + except BaseException: + with trio.CancelScope(deadline=30, shield=True): + await foo() # safe + raise # avoid TRIO103 + except: + with trio.CancelScope(deadline=30, shield=True): + await foo() # safe + raise # avoid TRIO103 diff --git a/tests/trio103_104.py b/tests/trio103_104.py index f6b49a4c..610f4bb6 100644 --- a/tests/trio103_104.py +++ b/tests/trio103_104.py @@ -172,3 +172,26 @@ def foo(): while True: continue raise + +# check for avoiding re-raise by yielding from function +def foo_yield(): + if True: # for code coverage + yield 1 + + try: + pass + except BaseException: + yield 1 # error + raise + + # check that we properly iterate over all nodes in try + except BaseException: + try: + yield 1 # error + except ValueError: + yield 1 # error + else: + yield 1 # error + finally: + yield 1 # error + raise diff --git a/tests/trio107_108.py b/tests/trio107_108.py index 0b108254..c46ea84f 100644 --- a/tests/trio107_108.py +++ b/tests/trio107_108.py @@ -107,7 +107,8 @@ async def foo_try_3(): # safe try: await foo() except ValueError: - await foo() + with trio.CancelScope(deadline=30, shield=True): # avoid TRIO102 + await foo() except: raise