From 221ee28eae3628a7e1e9e81b2309eb8a8fed9eff Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 13:09:04 +0200 Subject: [PATCH 1/9] Added #6: Async functions must have checkpoints on every code path --- CHANGELOG.md | 1 + README.md | 1 + flake8_trio.py | 75 ++++++++++++++++++++++++- tests/test_flake8_trio.py | 17 +++++- tests/trio100_py39.py | 1 + tests/trio102.py | 5 +- tests/trio300.py | 114 ++++++++++++++++++++++++++++++++++++++ tox.ini | 2 +- 8 files changed, 211 insertions(+), 5 deletions(-) create mode 100644 tests/trio300.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c39c847f..dd6ec15e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ *[CalVer, YY.month.patch](https://calver.org/)* ## Future +- Added TRIOXXX check: Async functions must have at least one checkpoint on every code path, unless an exception is raised - Add TRIO103: `except BaseException` or `except trio.Cancelled` with a code path that doesn't re-raise - Add TRIO104: "Cancelled and BaseException must be re-raised" if user tries to return or raise a different exception. diff --git a/README.md b/README.md index 2252d352..542d4697 100644 --- a/README.md +++ b/README.md @@ -28,3 +28,4 @@ pip install flake8-trio - **TRIO104**: `Cancelled` and `BaseException` must be re-raised - when a user tries to `return` or `raise` a different exception. - **TRIO105**: Calling a trio async function without immediately `await`ing it. - **TRIO106**: trio must be imported with `import trio` for the linter to work +- **TRIO300**: Async functions must have at least one checkpoint on every code path, unless an exception is raised diff --git a/flake8_trio.py b/flake8_trio.py index e5982548..aa355b76 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -11,7 +11,17 @@ import ast import tokenize -from typing import Any, Collection, Generator, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Collection, + Generator, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" __version__ = "22.7.4" @@ -47,6 +57,13 @@ def run(cls, tree: ast.AST) -> Generator[Error, None, None]: visitor.visit(tree) yield from visitor.problems + def visit_nodes(self, nodes: Iterable[ast.AST]) -> None: + for node in nodes: + self.visit(node) + + def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any): + self.problems.append(make_error(error, lineno, col, *args, **kwargs)) + class TrioScope: def __init__(self, node: ast.Call, funcname: str, packagename: str): @@ -462,6 +479,61 @@ def visit_Call(self, node: ast.Call): self.generic_visit(node) +class Visitor300(Flake8TrioVisitor): + def __init__(self) -> None: + super().__init__() + self.all_await = False + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): + outer = self.all_await + + self.all_await = False + self.generic_visit(node) + + if not self.all_await: + self.error(TRIO300, node.lineno, node.col_offset) + + self.all_await = outer + + def visit_Await( + self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith, ast.Raise] + ): + self.generic_visit(node) + self.all_await = True + + visit_AsyncFor = visit_Await + visit_AsyncWith = visit_Await + visit_Raise = visit_Await + + def visit_Try(self, node: ast.Try): + self.visit_nodes(node.body) + + # disregard await's in excepts + outer = self.all_await + self.visit_nodes(node.handlers) + self.all_await = outer + + self.visit_nodes(node.finalbody) + + def visit_If(self, node: ast.If): + if self.all_await: + self.generic_visit(node) + return + self.visit_nodes(node.body) + body_await = self.all_await + self.all_await = False + + self.visit_nodes(node.orelse) + self.all_await = body_await and self.all_await + + def visit_While(self, node: Union[ast.While, ast.For]): + outer = self.all_await + self.generic_visit(node) + self.all_await = outer + + visit_For = visit_While + + class Plugin: name = __name__ version = __version__ @@ -487,3 +559,4 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised" TRIO105 = "TRIO105: Trio async function {} must be immediately awaited" TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work" +TRIO300 = "TRIO300: Async functions must have at least one checkpoint on every code path, unless an exception is raised" diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index a6187ab7..816180e0 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -20,6 +20,7 @@ TRIO104, TRIO105, TRIO106, + TRIO300, Error, Plugin, make_error, @@ -94,7 +95,7 @@ def test_trio102(self): make_error(TRIO102, 92, 8), make_error(TRIO102, 94, 8), make_error(TRIO102, 101, 12), - make_error(TRIO102, 123, 12), + make_error(TRIO102, 124, 12), ) def test_trio103_104(self): @@ -173,6 +174,20 @@ def test_trio106(self): make_error(TRIO106, 6, 0), ) + def test_trio300(self): + self.assert_expected_errors( + "trio300.py", + make_error(TRIO300, 10, 0), + make_error(TRIO300, 15, 0), + make_error(TRIO300, 28, 0), + make_error(TRIO300, 33, 0), + make_error(TRIO300, 46, 0), + make_error(TRIO300, 51, 0), + make_error(TRIO300, 59, 0), + make_error(TRIO300, 90, 0), + make_error(TRIO300, 99, 0), + ) + @pytest.mark.fuzz class TestFuzz(unittest.TestCase): diff --git a/tests/trio100_py39.py b/tests/trio100_py39.py index a033e63b..97290dda 100644 --- a/tests/trio100_py39.py +++ b/tests/trio100_py39.py @@ -14,3 +14,4 @@ async def function_name(): trio.move_on_after(5), # error ): pass + await function_name() # avoid TRIO300 diff --git a/tests/trio102.py b/tests/trio102.py index 2111f28f..2f7545dd 100644 --- a/tests/trio102.py +++ b/tests/trio102.py @@ -5,7 +5,7 @@ async def foo(): try: - pass + await foo() # avoid TRIO300 finally: with trio.move_on_after(deadline=30) as s: s.shield = True @@ -107,11 +107,12 @@ async def foo2(): yield 1 finally: await foo() # safe + await foo() # avoid TRIO300 async def foo3(): try: - pass + await foo() # avoid TRIO300 finally: with trio.move_on_after(30) as s, trio.fail_after(5): s.shield = True diff --git a/tests/trio300.py b/tests/trio300.py new file mode 100644 index 00000000..04fdd506 --- /dev/null +++ b/tests/trio300.py @@ -0,0 +1,114 @@ +import trio + +_ = "" + + +async def foo(): + await foo() + + +async def foo2(): # error + ... + + +# If +async def foo_if_1(): # error + if _: + await foo() + + +async def foo_if_2(): + if _: + await foo() + else: + await foo() + + +# loops +async def foo_while_1(): # error + while _: + await foo() + + +async def foo_while_2(): # error: due to not wanting to handle continue/break semantics + while _: + await foo() + else: + await foo() + + +async def foo_while_3(): # safe + await foo() + while _: + ... + + +async def foo_for_1(): # error + for __ in _: + await foo() + + +async def foo_for_2(): # error: due to not wanting to handle continue/break semantics + for __ in _: + await foo() + else: + await foo() + + +# try +async def foo_try_1(): # error + try: + ... + except ValueError: + await foo() + except: + await foo() + + +async def foo_try_2(): # safe + try: + await foo() + except ValueError: + await foo() + except: + await foo() + + +async def foo_try_3(): # safe + try: + await foo() + except ValueError: + await foo() + except: + await foo() + finally: + with trio.CancelScope(deadline=30, shield=True): # avoid TRIO102 + await foo() + + +# early return +async def foo_return_1(): # error + return + + +async def foo_return_2(): # safe + await foo() + return + + +async def foo_return_3(): # error + if _: + await foo() + return + + +# raise +async def foo_raise_1(): # safe + raise ValueError() + + +async def foo_raise_2(): # safe + if _: + await foo() + else: + raise ValueError() diff --git a/tox.ini b/tox.ini index 3f35f024..92161f76 100644 --- a/tox.ini +++ b/tox.ini @@ -28,7 +28,7 @@ ignore_errors = commands = shed flake8 --exclude .*,tests/trio*.py - pyright --pythonversion 3.10 + pyright --pythonversion 3.10 --warnings # generate py38-test py39-test and test [testenv:{py38-, py39-,}test] From 3a72ee5be9f83193b9efab5462d277b9c388ebea Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 14:17:36 +0200 Subject: [PATCH 2/9] return, nested functions, ignore checkpoints in try, inline if --- CHANGELOG.md | 3 +- README.md | 3 ++ flake8_trio.py | 48 ++++++++++++++++++++----- tests/test_flake8_trio.py | 21 +++++++---- tests/trio300.py | 73 +++++++++++++++++++++++++++++++-------- 5 files changed, 119 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd6ec15e..e117c511 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,8 @@ *[CalVer, YY.month.patch](https://calver.org/)* ## Future -- Added TRIOXXX check: Async functions must have at least one checkpoint on every code path, unless an exception is raised +- Added TRIO300: Async functions must have at least one checkpoint on every code path, unless an exception is raised +- Added TRIO301: Early return from async function must have at least one checkpoint on every code path before it. - Add TRIO103: `except BaseException` or `except trio.Cancelled` with a code path that doesn't re-raise - Add TRIO104: "Cancelled and BaseException must be re-raised" if user tries to return or raise a different exception. diff --git a/README.md b/README.md index 542d4697..2a523d92 100644 --- a/README.md +++ b/README.md @@ -28,4 +28,7 @@ pip install flake8-trio - **TRIO104**: `Cancelled` and `BaseException` must be re-raised - when a user tries to `return` or `raise` a different exception. - **TRIO105**: Calling a trio async function without immediately `await`ing it. - **TRIO106**: trio must be imported with `import trio` for the linter to work +- - **TRIO300**: Async functions must have at least one checkpoint on every code path, unless an exception is raised +- **TRIO301**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised. +Checkpoints are `await`, `async with` `async for`. diff --git a/flake8_trio.py b/flake8_trio.py index aa355b76..86bac433 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -57,9 +57,12 @@ def run(cls, tree: ast.AST) -> Generator[Error, None, None]: visitor.visit(tree) yield from visitor.problems - def visit_nodes(self, nodes: Iterable[ast.AST]) -> None: - for node in nodes: - self.visit(node) + def visit_nodes(self, nodes: Union[ast.expr, Iterable[ast.AST]]) -> None: + if isinstance(nodes, ast.expr): + self.visit(nodes) + else: + for node in nodes: + self.visit(node) def error(self, error: str, lineno: int, col: int, *args: Any, **kwargs: Any): self.problems.append(make_error(error, lineno, col, *args, **kwargs)) @@ -482,7 +485,7 @@ def visit_Call(self, node: ast.Call): class Visitor300(Flake8TrioVisitor): def __init__(self) -> None: super().__init__() - self.all_await = False + self.all_await = True def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): outer = self.all_await @@ -495,6 +498,20 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): self.all_await = outer + def visit_Return(self, node: ast.Return): + self.generic_visit(node) + if not self.all_await: + self.error(TRIO301, node.lineno, node.col_offset) + # avoid duplicate error messages + self.all_await = True + + # disregard raise's in nested functions + def visit_FunctionDef(self, node: ast.FunctionDef): + outer = self.all_await + self.generic_visit(node) + self.all_await = outer + + # checkpoint functions def visit_Await( self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith, ast.Raise] ): @@ -503,22 +520,32 @@ def visit_Await( visit_AsyncFor = visit_Await visit_AsyncWith = visit_Await + + # raising exception means we don't need to checkpoint so we can treat it as one visit_Raise = visit_Await + # ignore checkpoints in try, excepts and orelse def visit_Try(self, node: ast.Try): - self.visit_nodes(node.body) - - # disregard await's in excepts outer = self.all_await + + self.visit_nodes(node.body) self.visit_nodes(node.handlers) + self.visit_nodes(node.orelse) + self.all_await = outer self.visit_nodes(node.finalbody) - def visit_If(self, node: ast.If): + # valid checkpoint if both body and orelse have checkpoints + def visit_If(self, node: Union[ast.If, ast.IfExp]): if self.all_await: self.generic_visit(node) return + + # ignore checkpoints in condition + self.visit_nodes(node.test) + self.all_await = False + self.visit_nodes(node.body) body_await = self.all_await self.all_await = False @@ -526,6 +553,10 @@ def visit_If(self, node: ast.If): self.visit_nodes(node.orelse) self.all_await = body_await and self.all_await + # inline if + visit_IfExp = visit_If + + # ignore checkpoints in loops due to continue/break shenanigans def visit_While(self, node: Union[ast.While, ast.For]): outer = self.all_await self.generic_visit(node) @@ -560,3 +591,4 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: TRIO105 = "TRIO105: Trio async function {} must be immediately awaited" TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work" TRIO300 = "TRIO300: Async functions must have at least one checkpoint on every code path, unless an exception is raised" +TRIO301 = "TRIO301: Early return from async function must have at least one checkpoint on every code path before it." diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 816180e0..2222042b 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -21,6 +21,7 @@ TRIO105, TRIO106, TRIO300, + TRIO301, Error, Plugin, make_error, @@ -179,13 +180,21 @@ def test_trio300(self): "trio300.py", make_error(TRIO300, 10, 0), make_error(TRIO300, 15, 0), - make_error(TRIO300, 28, 0), make_error(TRIO300, 33, 0), - make_error(TRIO300, 46, 0), - make_error(TRIO300, 51, 0), - make_error(TRIO300, 59, 0), - make_error(TRIO300, 90, 0), - make_error(TRIO300, 99, 0), + make_error(TRIO300, 43, 0), + make_error(TRIO300, 48, 0), + make_error(TRIO300, 53, 0), + make_error(TRIO300, 66, 0), + make_error(TRIO300, 71, 0), + make_error(TRIO300, 79, 0), + make_error(TRIO301, 102, 4), + make_error(TRIO301, 107, 8), + make_error(TRIO300, 111, 0), + make_error(TRIO300, 133, 4), + make_error(TRIO300, 137, 0), + make_error(TRIO300, 144, 8), + make_error(TRIO300, 142, 0), + make_error(TRIO300, 148, 0), ) diff --git a/tests/trio300.py b/tests/trio300.py index 04fdd506..4e5e066b 100644 --- a/tests/trio300.py +++ b/tests/trio300.py @@ -24,6 +24,26 @@ async def foo_if_2(): await foo() +async def foo_if_3(): + await foo() + if _: + ... + + +async def foo_if_4(): # error + if await foo(): + ... + + +# IfExp +async def foo_ifexp_1(): # safe + print(await foo() if _ else await foo()) + + +async def foo_ifexp_2(): # error + print(_ if await foo() else await foo()) + + # loops async def foo_while_1(): # error while _: @@ -57,15 +77,6 @@ async def foo_for_2(): # error: due to not wanting to handle continue/break sem # try async def foo_try_1(): # error - try: - ... - except ValueError: - await foo() - except: - await foo() - - -async def foo_try_2(): # safe try: await foo() except ValueError: @@ -74,7 +85,7 @@ async def foo_try_2(): # safe await foo() -async def foo_try_3(): # safe +async def foo_try_2(): # safe try: await foo() except ValueError: @@ -87,19 +98,20 @@ async def foo_try_3(): # safe # early return -async def foo_return_1(): # error - return +async def foo_return_1(): # silent to avoid duplicate errors + return # error async def foo_return_2(): # safe + if _: + return # error await foo() - return async def foo_return_3(): # error if _: await foo() - return + return # safe # raise @@ -112,3 +124,36 @@ async def foo_raise_2(): # safe await foo() else: raise ValueError() + + +# nested function definition +async def foo_func_1(): + await foo() + + async def foo_func_2(): # error + ... + + +async def foo_func_3(): # error + async def foo_func_4(): + await foo() + + +async def foo_func_5(): # error + def foo_func_6(): # safe + async def foo_func_7(): # error + ... + + +async def foo_func_8(): # error + def foo_func_9(): + raise + + +# normal function +def foo_normal_func_1(): + return + + +def foo_normal_func_2(): + ... From bfa299943976d3cb89417a2754891db8dd530fb8 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 14:20:48 +0200 Subject: [PATCH 3/9] 300 -> 300_301 --- flake8_trio.py | 2 +- tests/test_flake8_trio.py | 4 ++-- tests/{trio300.py => trio300_301.py} | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename tests/{trio300.py => trio300_301.py} (100%) diff --git a/flake8_trio.py b/flake8_trio.py index 86bac433..15f2b31b 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -482,7 +482,7 @@ def visit_Call(self, node: ast.Call): self.generic_visit(node) -class Visitor300(Flake8TrioVisitor): +class Visitor300_301(Flake8TrioVisitor): def __init__(self) -> None: super().__init__() self.all_await = True diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 2222042b..8350e32f 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -175,9 +175,9 @@ def test_trio106(self): make_error(TRIO106, 6, 0), ) - def test_trio300(self): + def test_trio300_301(self): self.assert_expected_errors( - "trio300.py", + "trio300_301.py", make_error(TRIO300, 10, 0), make_error(TRIO300, 15, 0), make_error(TRIO300, 33, 0), diff --git a/tests/trio300.py b/tests/trio300_301.py similarity index 100% rename from tests/trio300.py rename to tests/trio300_301.py From d2a31c693003fd711751b81d28acae5b9ad77c6d Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 14:45:55 +0200 Subject: [PATCH 4/9] handle case of (try or else) and all(except), reorder tests --- flake8_trio.py | 32 ++++++++++++++++++----- tests/test_flake8_trio.py | 22 ++++++++++------ tests/test_trio_tests.py | 4 ++- tests/trio300_301.py | 54 +++++++++++++++++++++++++++------------ 4 files changed, 81 insertions(+), 31 deletions(-) diff --git a/flake8_trio.py b/flake8_trio.py index 15f2b31b..ed834208 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -57,8 +57,8 @@ def run(cls, tree: ast.AST) -> Generator[Error, None, None]: visitor.visit(tree) yield from visitor.problems - def visit_nodes(self, nodes: Union[ast.expr, Iterable[ast.AST]]) -> None: - if isinstance(nodes, ast.expr): + def visit_nodes(self, nodes: Union[ast.AST, Iterable[ast.AST]]) -> None: + if isinstance(nodes, ast.AST): self.visit(nodes) else: for node in nodes: @@ -524,16 +524,33 @@ def visit_Await( # raising exception means we don't need to checkpoint so we can treat it as one visit_Raise = visit_Await - # ignore checkpoints in try, excepts and orelse + # valid checkpoint if there's valid checkpoints (or raise) in at least one of: + # (try or else) and all excepts + # finally def visit_Try(self, node: ast.Try): - outer = self.all_await + if self.all_await: + self.generic_visit(node) + return + # check try body self.visit_nodes(node.body) - self.visit_nodes(node.handlers) + body_await = self.all_await + self.all_await = False + + # check that all except handlers checkpoint (await or most likely raise) + all_except_await = True + for handler in node.handlers: + self.visit_nodes(handler) + all_except_await &= self.all_await + self.all_await = False + + # check else self.visit_nodes(node.orelse) - self.all_await = outer + # (try or else) and all excepts + self.all_await = (body_await or self.all_await) and all_except_await + # finally can check on it's own self.visit_nodes(node.finalbody) # valid checkpoint if both body and orelse have checkpoints @@ -546,11 +563,14 @@ def visit_If(self, node: Union[ast.If, ast.IfExp]): self.visit_nodes(node.test) self.all_await = False + # check body self.visit_nodes(node.body) body_await = self.all_await self.all_await = False self.visit_nodes(node.orelse) + + # checkpoint if both body and else self.all_await = body_await and self.all_await # inline if diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 8350e32f..482efac2 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -179,22 +179,28 @@ def test_trio300_301(self): self.assert_expected_errors( "trio300_301.py", make_error(TRIO300, 10, 0), + # if make_error(TRIO300, 15, 0), make_error(TRIO300, 33, 0), + # ifexp make_error(TRIO300, 43, 0), + # loops make_error(TRIO300, 48, 0), make_error(TRIO300, 53, 0), make_error(TRIO300, 66, 0), make_error(TRIO300, 71, 0), + # try make_error(TRIO300, 79, 0), - make_error(TRIO301, 102, 4), - make_error(TRIO301, 107, 8), - make_error(TRIO300, 111, 0), - make_error(TRIO300, 133, 4), - make_error(TRIO300, 137, 0), - make_error(TRIO300, 144, 8), - make_error(TRIO300, 142, 0), - make_error(TRIO300, 148, 0), + # early return + make_error(TRIO301, 136, 4), + make_error(TRIO301, 141, 8), + # nested function definition + make_error(TRIO300, 145, 0), + make_error(TRIO300, 155, 4), + make_error(TRIO300, 159, 0), + make_error(TRIO300, 166, 8), + make_error(TRIO300, 164, 0), + make_error(TRIO300, 170, 0), ) diff --git a/tests/test_trio_tests.py b/tests/test_trio_tests.py index 580bc304..e762b8e5 100644 --- a/tests/test_trio_tests.py +++ b/tests/test_trio_tests.py @@ -49,4 +49,6 @@ def runTest(self): self.assertNotIn(lineno, func_error_lines, msg=test) func_error_lines.add(lineno) - self.assertSetEqual(file_error_lines, func_error_lines, msg=test) + self.assertSequenceEqual( + sorted(file_error_lines), sorted(func_error_lines), msg=test + ) diff --git a/tests/trio300_301.py b/tests/trio300_301.py index 4e5e066b..b0f1e96c 100644 --- a/tests/trio300_301.py +++ b/tests/trio300_301.py @@ -80,23 +80,57 @@ async def foo_try_1(): # error try: await foo() except ValueError: - await foo() + ... except: await foo() + else: + await foo() async def foo_try_2(): # safe try: - await foo() + ... except ValueError: - await foo() + ... except: - await foo() + ... finally: with trio.CancelScope(deadline=30, shield=True): # avoid TRIO102 await foo() +async def foo_try_3(): # safe + try: + await foo() + except ValueError: + await foo() + except: + await foo() + + +# raise +async def foo_raise_1(): # safe + raise ValueError() + + +async def foo_raise_2(): # safe + if _: + await foo() + else: + raise ValueError() + + +async def foo_try_4(): # safe + try: + ... + except ValueError: + raise + except: + raise + else: + await foo() + + # early return async def foo_return_1(): # silent to avoid duplicate errors return # error @@ -114,18 +148,6 @@ async def foo_return_3(): # error return # safe -# raise -async def foo_raise_1(): # safe - raise ValueError() - - -async def foo_raise_2(): # safe - if _: - await foo() - else: - raise ValueError() - - # nested function definition async def foo_func_1(): await foo() From 2a6ac5b3e536e55799bf7f4bb9f38247f3631042 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 21:11:11 +0200 Subject: [PATCH 5/9] added comment, overloaded functions are safe --- flake8_trio.py | 21 ++++++--------------- tests/test_flake8_trio.py | 34 +++++++++++++++++----------------- tests/trio300_301.py | 21 ++++++++++++++++++++- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/flake8_trio.py b/flake8_trio.py index ed834208..dd7a276a 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -11,17 +11,7 @@ import ast import tokenize -from typing import ( - Any, - Collection, - Generator, - Iterable, - List, - Optional, - Tuple, - Type, - Union, -) +from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" __version__ = "22.7.4" @@ -108,7 +98,7 @@ def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]: return None -def has_decorator(decorator_list: List[ast.expr], names: Collection[str]): +def has_decorator(decorator_list: List[ast.expr], *names: str): for dec in decorator_list: if (isinstance(dec, ast.Name) and dec.id in names) or ( isinstance(dec, ast.Attribute) and dec.attr in names @@ -155,7 +145,7 @@ def visit_FunctionDef( self._yield_is_error = False # check for @ and @. - if has_decorator(node.decorator_list, context_manager_names): + if has_decorator(node.decorator_list, *context_manager_names): self._context_manager = True self.generic_visit(node) @@ -258,7 +248,7 @@ def visit_FunctionDef( outer_cm = self._context_manager # check for @ and @. - if has_decorator(node.decorator_list, context_manager_names): + if has_decorator(node.decorator_list, *context_manager_names): self._context_manager = True self.generic_visit(node) @@ -490,7 +480,8 @@ def __init__(self) -> None: def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): outer = self.all_await - self.all_await = False + # do not require checkpointing if overloading + self.all_await = has_decorator(node.decorator_list, "overload") self.generic_visit(node) if not self.all_await: diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 482efac2..d8ecea5d 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -178,29 +178,29 @@ def test_trio106(self): def test_trio300_301(self): self.assert_expected_errors( "trio300_301.py", - make_error(TRIO300, 10, 0), + make_error(TRIO300, 13, 0), # if - make_error(TRIO300, 15, 0), - make_error(TRIO300, 33, 0), + make_error(TRIO300, 18, 0), + make_error(TRIO300, 36, 0), # ifexp - make_error(TRIO300, 43, 0), + make_error(TRIO300, 46, 0), # loops - make_error(TRIO300, 48, 0), - make_error(TRIO300, 53, 0), - make_error(TRIO300, 66, 0), - make_error(TRIO300, 71, 0), + make_error(TRIO300, 51, 0), + make_error(TRIO300, 56, 0), + make_error(TRIO300, 69, 0), + make_error(TRIO300, 74, 0), # try - make_error(TRIO300, 79, 0), + make_error(TRIO300, 83, 0), # early return - make_error(TRIO301, 136, 4), - make_error(TRIO301, 141, 8), + make_error(TRIO301, 140, 4), + make_error(TRIO301, 145, 8), # nested function definition - make_error(TRIO300, 145, 0), - make_error(TRIO300, 155, 4), - make_error(TRIO300, 159, 0), - make_error(TRIO300, 166, 8), - make_error(TRIO300, 164, 0), - make_error(TRIO300, 170, 0), + make_error(TRIO300, 149, 0), + make_error(TRIO300, 159, 4), + make_error(TRIO300, 163, 0), + make_error(TRIO300, 170, 8), + make_error(TRIO300, 168, 0), + make_error(TRIO300, 174, 0), ) diff --git a/tests/trio300_301.py b/tests/trio300_301.py index b0f1e96c..29c5c714 100644 --- a/tests/trio300_301.py +++ b/tests/trio300_301.py @@ -1,3 +1,6 @@ +import typing +from typing import Union, overload + import trio _ = "" @@ -76,7 +79,8 @@ async def foo_for_2(): # error: due to not wanting to handle continue/break sem # try -async def foo_try_1(): # error +# safe only if (try or else) and all except bodies either await or raise +async def foo_try_1(): # error: if foo() raises a ValueError it's not checkpointed try: await foo() except ValueError: @@ -179,3 +183,18 @@ def foo_normal_func_1(): def foo_normal_func_2(): ... + + +# overload decorator +@overload +async def foo_overload_1(_: bytes): + ... + + +@typing.overload +async def foo_overload_1(_: str): + ... + + +async def foo_overload_1(_: Union[bytes, str]): + await foo() From 10a79d0edb622eb50a27a88fef079c9c78ab84d1 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 21:15:16 +0200 Subject: [PATCH 6/9] 300,301 -> 107,108 --- CHANGELOG.md | 4 +-- README.md | 4 +-- flake8_trio.py | 10 +++--- tests/test_flake8_trio.py | 42 ++++++++++++------------ tests/{trio300_301.py => trio107_108.py} | 0 5 files changed, 30 insertions(+), 30 deletions(-) rename tests/{trio300_301.py => trio107_108.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index e117c511..481f0f5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,8 @@ *[CalVer, YY.month.patch](https://calver.org/)* ## Future -- Added TRIO300: Async functions must have at least one checkpoint on every code path, unless an exception is raised -- Added TRIO301: Early return from async function must have at least one checkpoint on every code path before it. +- Added TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised +- Added TRIO108: Early return from async function must have at least one checkpoint on every code path before it. - Add TRIO103: `except BaseException` or `except trio.Cancelled` with a code path that doesn't re-raise - Add TRIO104: "Cancelled and BaseException must be re-raised" if user tries to return or raise a different exception. diff --git a/README.md b/README.md index 2a523d92..e17a6a8f 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,6 @@ pip install flake8-trio - **TRIO105**: Calling a trio async function without immediately `await`ing it. - **TRIO106**: trio must be imported with `import trio` for the linter to work - -- **TRIO300**: Async functions must have at least one checkpoint on every code path, unless an exception is raised -- **TRIO301**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised. +- **TRIO107**: Async functions must have at least one checkpoint on every code path, unless an exception is raised +- **TRIO108**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised. Checkpoints are `await`, `async with` `async for`. diff --git a/flake8_trio.py b/flake8_trio.py index dd7a276a..38182df5 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -472,7 +472,7 @@ def visit_Call(self, node: ast.Call): self.generic_visit(node) -class Visitor300_301(Flake8TrioVisitor): +class Visitor107_108(Flake8TrioVisitor): def __init__(self) -> None: super().__init__() self.all_await = True @@ -485,14 +485,14 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): self.generic_visit(node) if not self.all_await: - self.error(TRIO300, node.lineno, node.col_offset) + self.error(TRIO107, node.lineno, node.col_offset) self.all_await = outer def visit_Return(self, node: ast.Return): self.generic_visit(node) if not self.all_await: - self.error(TRIO301, node.lineno, node.col_offset) + self.error(TRIO108, node.lineno, node.col_offset) # avoid duplicate error messages self.all_await = True @@ -601,5 +601,5 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: TRIO104 = "TRIO104: Cancelled (and therefore BaseException) must be re-raised" TRIO105 = "TRIO105: Trio async function {} must be immediately awaited" TRIO106 = "TRIO106: trio must be imported with `import trio` for the linter to work" -TRIO300 = "TRIO300: Async functions must have at least one checkpoint on every code path, unless an exception is raised" -TRIO301 = "TRIO301: Early return from async function must have at least one checkpoint on every code path before it." +TRIO107 = "TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised" +TRIO108 = "TRIO108: Early return from async function must have at least one checkpoint on every code path before it." diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index d8ecea5d..05c564a6 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -20,8 +20,8 @@ TRIO104, TRIO105, TRIO106, - TRIO300, - TRIO301, + TRIO107, + TRIO108, Error, Plugin, make_error, @@ -175,32 +175,32 @@ def test_trio106(self): make_error(TRIO106, 6, 0), ) - def test_trio300_301(self): + def test_trio107_108(self): self.assert_expected_errors( - "trio300_301.py", - make_error(TRIO300, 13, 0), + "trio107_108.py", + make_error(TRIO107, 13, 0), # if - make_error(TRIO300, 18, 0), - make_error(TRIO300, 36, 0), + make_error(TRIO107, 18, 0), + make_error(TRIO107, 36, 0), # ifexp - make_error(TRIO300, 46, 0), + make_error(TRIO107, 46, 0), # loops - make_error(TRIO300, 51, 0), - make_error(TRIO300, 56, 0), - make_error(TRIO300, 69, 0), - make_error(TRIO300, 74, 0), + make_error(TRIO107, 51, 0), + make_error(TRIO107, 56, 0), + make_error(TRIO107, 69, 0), + make_error(TRIO107, 74, 0), # try - make_error(TRIO300, 83, 0), + make_error(TRIO107, 83, 0), # early return - make_error(TRIO301, 140, 4), - make_error(TRIO301, 145, 8), + make_error(TRIO108, 140, 4), + make_error(TRIO108, 145, 8), # nested function definition - make_error(TRIO300, 149, 0), - make_error(TRIO300, 159, 4), - make_error(TRIO300, 163, 0), - make_error(TRIO300, 170, 8), - make_error(TRIO300, 168, 0), - make_error(TRIO300, 174, 0), + make_error(TRIO107, 149, 0), + make_error(TRIO107, 159, 4), + make_error(TRIO107, 163, 0), + make_error(TRIO107, 170, 8), + make_error(TRIO107, 168, 0), + make_error(TRIO107, 174, 0), ) diff --git a/tests/trio300_301.py b/tests/trio107_108.py similarity index 100% rename from tests/trio300_301.py rename to tests/trio107_108.py From 6ad9d9670273c6b1c6a905edfbbf4f5ab1ca71ff Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 21:26:45 +0200 Subject: [PATCH 7/9] fixed TRIO103 in trio107_108.py --- tests/trio107_108.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trio107_108.py b/tests/trio107_108.py index 29c5c714..0b108254 100644 --- a/tests/trio107_108.py +++ b/tests/trio107_108.py @@ -86,7 +86,7 @@ async def foo_try_1(): # error: if foo() raises a ValueError it's not checkpoin except ValueError: ... except: - await foo() + raise else: await foo() @@ -97,7 +97,7 @@ async def foo_try_2(): # safe except ValueError: ... except: - ... + raise finally: with trio.CancelScope(deadline=30, shield=True): # avoid TRIO102 await foo() @@ -109,7 +109,7 @@ async def foo_try_3(): # safe except ValueError: await foo() except: - await foo() + raise # raise From 1780e3c3966ed011e8f2d864ab7bc13b314d29e1 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 30 Jul 2022 21:29:08 +0200 Subject: [PATCH 8/9] incremented version --- CHANGELOG.md | 6 +++--- flake8_trio.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 481f0f5d..007454e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,11 @@ # Changelog *[CalVer, YY.month.patch](https://calver.org/)* -## Future -- Added TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised -- Added TRIO108: Early return from async function must have at least one checkpoint on every code path before it. +## 22.7.5 - Add TRIO103: `except BaseException` or `except trio.Cancelled` with a code path that doesn't re-raise - Add TRIO104: "Cancelled and BaseException must be re-raised" if user tries to return or raise a different exception. +- Added TRIO107: Async functions must have at least one checkpoint on every code path, unless an exception is raised +- Added TRIO108: Early return from async function must have at least one checkpoint on every code path before it. ## 22.7.4 - Added TRIO105 check for not immediately `await`ing async trio functions. diff --git a/flake8_trio.py b/flake8_trio.py index 38182df5..af3885aa 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -14,7 +14,7 @@ from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" -__version__ = "22.7.4" +__version__ = "22.7.5" Error = Tuple[int, int, str, Type[Any]] From 81ebc1cc83ead69a81fa81373fe888928df7c640 Mon Sep 17 00:00:00 2001 From: Zac Hatfield-Dodds Date: Sun, 31 Jul 2022 01:44:21 -0700 Subject: [PATCH 9/9] Update comments --- tests/trio100_py39.py | 2 +- tests/trio102.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trio100_py39.py b/tests/trio100_py39.py index 97290dda..95ed9424 100644 --- a/tests/trio100_py39.py +++ b/tests/trio100_py39.py @@ -14,4 +14,4 @@ async def function_name(): trio.move_on_after(5), # error ): pass - await function_name() # avoid TRIO300 + await function_name() # avoid TRIO107 diff --git a/tests/trio102.py b/tests/trio102.py index 2f7545dd..35b5a579 100644 --- a/tests/trio102.py +++ b/tests/trio102.py @@ -5,7 +5,7 @@ async def foo(): try: - await foo() # avoid TRIO300 + await foo() # avoid TRIO107 finally: with trio.move_on_after(deadline=30) as s: s.shield = True @@ -107,12 +107,12 @@ async def foo2(): yield 1 finally: await foo() # safe - await foo() # avoid TRIO300 + await foo() # avoid TRIO107 async def foo3(): try: - await foo() # avoid TRIO300 + await foo() # avoid TRIO107 finally: with trio.move_on_after(30) as s, trio.fail_after(5): s.shield = True