8000 Trio102: await inside finally needs shielding and timeout by jakkdl · Pull Request #5 · python-trio/flake8-async · GitHub
[go: up one dir, main page]

Skip to content

Trio102: await inside finally needs shielding and timeout #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Changelog
*[CalVer, YY.month.patch](https://calver.org/)*

## 22.7.3
- Added TRIO102 check for unsafe checkpoints inside `finally:` blocks

## 22.7.2
- Avoid `TRIO100` false-alarms on cancel scopes containing `async for` or `async with`.

## 22.7.1
- Initial release
- Initial release with TRIO100 and TRIO101
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ pip install flake8-trio
context does not contain any `await` statements. This makes it pointless, as
the timeout can only be triggered by a checkpoint.
- **TRIO101** `yield` inside a nursery or cancel scope is only safe when implementing a context manager - otherwise, it breaks exception handling.
- **TRIO102** it's unsafe to await inside `finally:` unless you use a shielded
cancel scope with a timeout"
204 changes: 173 additions & 31 deletions flake8_trio.py
9E7A
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

import ast
import tokenize
from typing import Any, Generator, List, Optional, Set, Tuple, Type, Union
from typing import Any, Collection, Generator, List, Optional, Tuple, Type, Union

# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.7.2"
__version__ = "22.7.3"


Error = Tuple[int, int, str, Type[Any]]
Expand All @@ -36,66 +36,206 @@ def make_error(error: str, lineno: int, col: int, *args: Any, **kwargs: Any) ->
return (lineno, col, error.format(*args, **kwargs), type(Plugin))


def is_trio_call(node: ast.AST, *names: str) -> Optional[str]:
class TrioScope:
def __init__(self, node: ast.Call, funcname: str, packagename: str):
self.node = node
self.funcname = funcname
self.packagename = packagename
self.variable_name: Optional[str] = None
self.shielded: bool = False
self.has_timeout: bool = False

if self.funcname == "CancelScope":
for kw in node.keywords:
# Only accepts constant values
if kw.arg == "shield" and isinstance(kw.value, ast.Constant):
self.shielded = kw.value.value
# sets to True even if timeout is explicitly set to inf
if kw.arg == "deadline":
self.has_timeout = True
else:
self.has_timeout = True

def __str__(self):
# Not supporting other ways of importing trio
# if self.packagename is None:
# return self.funcname
return f"{self.packagename}.{self.funcname}"


def get_trio_scope(node: ast.AST, *names: str) -> Optional[TrioScope]:
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "trio"
and node.func.attr in names
):
return "trio." + node.func.attr
# return "trio." + node.func.attr
return TrioScope(node, node.func.attr, node.func.value.id)
return None


def has_decorator(decorator_list: List[ast.expr], names: Collection[str]):
for dec in decorator_list:
if (isinstance(dec, ast.Name) and dec.id in names) or (
isinstance(dec, ast.Attribute) and dec.attr in names
):
return True
return False


class Visitor102(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.problems: List[Error] = []
self._inside_finally: bool = False
self._scopes: List[TrioScope] = []
self._context_manager = False

def visit_Assign(self, node: ast.Assign) -> None:
# checks for <scopename>.shield = [True/False]
if self._scopes and len(node.targets) == 1:
last_scope = self._scopes[-1]
target = node.targets[0]
if (
last_scope.variable_name is not None
and isinstance(target, ast.Attribute)
and isinstance(target.value, ast.Name)
and target.value.id == last_scope.variable_name
and target.attr == "shield"
and isinstance(node.value, ast.Constant)
):
last_scope.shielded = node.value.value
self.generic_visit(node)

def visit_Await(self, node: ast.Await) -> None:
self.check_for_trio102(node)
self.generic_visit(node)

def visit_With(self, node: Union[ast.With, ast.AsyncWith]) -> None:
trio_scope = None

# Check for a `with trio.<scope_creater>`
for item in node.items:
trio_scope = get_trio_scope(
item.context_expr, "open_nursery", *cancel_scope_names
)
if trio_scope is not None:
# check if it's saved in a variable
if isinstance(item.optional_vars, ast.Name):
trio_scope.variable_name = item.optional_vars.id
break

if trio_scope is not None:
self._scopes.append(trio_scope)

self.generic_visit(node)

if trio_scope is not None:
self._scopes.pop()

def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
self.check_for_trio102(node)
self.visit_With(node)

def visit_AsyncFor(self, node: ast.AsyncFor) -> None:
self.check_for_trio102(node)
self.generic_visit(node)

def visit_FunctionDef(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
outer_cm = self._context_manager

# check for @<context_manager_name> and @<library>.<context_manager_name>
if has_decorator(node.decorator_list, context_manager_names):
self._context_manager = True

self.generic_visit(node)
self._context_manager = outer_cm

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.visit_FunctionDef(node)

def visit_Try(self, node: ast.Try) -> None:
# There's no visit_Finally, so we need to manually visit the Try fields.
# It's important to do self.visit instead of self.generic_visit since
# the nodes in the fields might be registered elsewhere in this class.
for item in (*node.body, *node.handlers, *node.orelse):
self.visit(item)

outer = self._inside_finally
outer_scopes = self._scopes

self._scopes = []
self._inside_finally = True

for item in node.finalbody:
self.visit(item)

self._scopes = outer_scopes
self._inside_finally = outer

def check_for_trio102(self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith]):
# if we're inside a finally, and not inside a context_manager, and we're not
# inside a scope that doesn't have both a timeout and shield
if (
self._inside_finally
and not self._context_manager
and not any(scope.has_timeout and scope.shielded for scope in self._scopes)
):
self.problems.append(make_error(TRIO102, node.lineno, node.col_offset))


class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.problems: List[Error] = []
self.safe_yields: Set[ast.Yield] = set()
self._yield_is_error = False
self._context_manager = False

def visit_generic_with(self, node: Union[ast.With, ast.AsyncWith]):
def visit_With(self, node: Union[ast.With, ast.AsyncWith]) -> None:
self.check_for_trio100(node)

outer = self._yield_is_error
if not self._context_manager and any(
is_trio_call(item, "open_nursery", *cancel_scope_names)
for item in (i.context_expr for i in node.items)
):
self._yield_is_error = True
outer_yie = self._yield_is_error

# Check for a `with trio.<scope_creater>`
if not self._context_manager:
for item in (i.context_expr for i in node.items):
if (
get_trio_scope(item, "open_nursery", *cancel_scope_names)
is not None
):
self._yield_is_error = True
break

self.generic_visit(node)
self._yield_is_error = outer

def visit_With(self, node: ast.With) -> None:
self.visit_generic_with(node)
# reset yield_is_error
self._yield_is_error = outer_yie

def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
self.visit_generic_with(node)
self.visit_With(node)

def visit_generic_FunctionDef(
def visit_FunctionDef(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
):
) -> None:
outer_cm = self._context_manager
outer_yie = self._yield_is_error
self._yield_is_error = False
if any(
(isinstance(d, ast.Name) and d.id in context_manager_names)
or (isinstance(d, ast.Attribute) and d.attr in context_manager_names)
for d in node.decorator_list
):

# check for @<context_manager_name> and @<library>.<context_manager_name>
if has_decorator(node.decorator_list, context_manager_names):
self._context_manager = True

self.generic_visit(node)

self._context_manager = outer_cm
self._yield_is_error = outer_yie

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.visit_generic_FunctionDef(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.visit_generic_FunctionDef(node)
self.visit_FunctionDef(node)

def visit_Yield(self, node: ast.Yield) -> None:
if self._yield_is_error:
Expand All @@ -106,7 +246,7 @@ def visit_Yield(self, node: ast.Yield) -> None:
def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None:
# Context manager with no `await` call within
for item in (i.context_expr for i in node.items):
call = is_trio_call(item, *cancel_scope_names)
call = get_trio_scope(item, *cancel_scope_names)
if call and not any(
isinstance(x, checkpoint_node_types) for x in ast.walk(node)
):
Expand All @@ -129,10 +269,12 @@ def from_filename(cls, filename: str) -> "Plugin":
return cls(ast.parse(source))

def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
visitor = Visitor()
visitor.visit(self._tree)
yield from visitor.problems
for v in (Visitor, Visitor102):
visitor = v()
visitor.visit(self._tree)
yield from visitor.problems


TRIO100 = "TRIO100: {} context contains no checkpoints, add `await trio.sleep(0)`"
TRIO101 = "TRIO101: yield inside a nursery or cancel scope is only safe when implementing a context manager - otherwise, it breaks exception handling"
TRIO102 = "TRIO102: it's unsafe to await inside `finally:` unless you use a shielded cancel scope with a timeout"
29 changes: 28 additions & 1 deletion tests/test_flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from hypothesis import HealthCheck, given, settings
from hypothesmith import from_grammar, from_node

from flake8_trio import TRIO100, TRIO101, Error, Plugin, Visitor, make_error
from flake8_trio import TRIO100, TRIO101, TRIO102, Error, Plugin, Visitor, make_error


class Flake8TrioTestCase(unittest.TestCase):
Expand Down FFE9 Expand Up @@ -48,6 +48,33 @@ def test_trio101(self):
make_error(TRIO101, 15, 8),
make_error(TRIO101, 27, 8),
make_error(TRIO101, 38, 8),
make_error(TRIO101, 59, 8),
)

def test_trio102(self):
self.assert_expected_errors(
"trio102.py",
make_error(TRIO102, 24, 8),
make_error(TRIO102, 30, 12),
make_error(TRIO102, 36, 12),
make_error(TRIO102, 62, 12),
make_error(TRIO102, 70, 12),
make_error(TRIO102, 74, 12),
make_error(TRIO102, 76, 12),
make_error(TRIO102, 80, 12),
make_error(TRIO102, 82, 12),
make_error(TRIO102, 84, 12),
make_error(TRIO102, 88, 12),
make_error(TRIO102, 92, 8),
make_error(TRIO102, 94, 8),
make_error(TRIO102, 101, 12),
)

@unittest.skipIf(sys.version_info < (3, 9), "requires 3.9+")
def test_trio102_py39(self):
self.assert_expected_errors(
"trio102_py39.py",
make_error(TRIO102, 15, 12),
)


Expand Down
8 changes: 7 additions & 1 deletion tests/trio101.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import contextlib as bla
from contextlib import asynccontextmanager, contextmanager
from contextlib import asynccontextmanager, contextmanager, contextmanager as blahabla

import trio

Expand Down Expand Up @@ -51,3 +51,9 @@ async def foo7():
def foo8():
with trio.open_nursery() as _:
yield 1 # safe


@blahabla
def foo9():
with trio.open_nursery() as _:
yield 1 # error
Loading
0