8000 rework 302 according to new specifications · python-trio/flake8-async@f4f5f16 · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f4f5f16

Browse files
committed
rework 302 according to new specifications
1 parent 5c344bb commit f4f5f16

File tree

3 files changed

+240
-78
lines changed

3 files changed

+240
-78
lines changed

flake8_trio.py

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import ast
1313
import tokenize
14-
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set, Union
14+
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
1515

1616
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
1717
__version__ = "22.8.3"
@@ -29,7 +29,7 @@
2929
"TRIO108": "{0} from async iterable with no guaranteed checkpoint since {1.name} on line {1.lineno}",
3030
"TRIO109": "Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead",
3131
"TRIO110": "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.",
32-
"TRIO302": "async context manager inside nursery opened on line {}. Nurseries should be outermost.",
32+
"TRIO302": "call to nursery.start/start_soon with resource from context manager opened on line {} something something nursery on line {}",
3333
}
3434

3535

@@ -40,7 +40,7 @@ class Statement(NamedTuple):
4040

4141
# ignore col offset since many tests don't supply that
4242
def __eq__(self, other: Any) -> bool:
43-
return isinstance(other, Statement) and self[:2] == other[:2]
43+
return isinstance(other, Statement) and self[:2] == other[:2] # type: ignore
4444

4545

4646
HasLineInfo = Union[ast.expr, ast.stmt, ast.arg, ast.excepthandler, Statement]
@@ -140,10 +140,19 @@ def error(self, error: str, node: HasLineInfo, *args: object):
140140
if not self.suppress_errors:
141141
self._problems.append(Error(error, node.lineno, node.col_offset, *args))
142142

143-
def get_state(self, *attrs: str) -> Dict[str, Any]:
143+
def get_state(self, *attrs: str, copy: bool = False) -> Dict[str, Any]:
144144
if not attrs:
145145
attrs = tuple(self.__dict__.keys())
146-
return {attr: getattr(self, attr) for attr in attrs if< 8000 /span> attr != "_problems"}
146+
res: Dict[str, Any] = {}
147+
for attr in attrs:
148+
if attr == "_problems":
149+
continue
150+
value = getattr(self, attr)
151+
if copy and hasattr(value, "copy"):
152+
value = value.copy()
153+
res[attr] = value
154+
return res
155+
# return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
147156

148157
def set_state(self, attrs: Dict[str, Any], copy: bool = False):
149158
for attr, value in attrs.items():
@@ -185,36 +194,41 @@ def __init__(self):
185194
# variables only used for 101
186195
self._yield_is_error = False
187196
self._safe_decorator = False
188-
self._inside_nursery: Optional[int] = None
197+
self._context_manager_stack: List[Tuple[ast.expr, str, bool]] = []
189198

190-
# ---- 100, 101 ----
199+
# ---- 100, 101, 302 ----
191200
def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
192-
# 100
193201
self.check_for_trio100(node)
194202

195-
# 101 for rest of function
196-
outer = self.get_state("_yield_is_error")
203+
outer = self.get_state("_yield_is_error", "_context_manager_stack", copy=True)
197204

198205
# Check for a `with trio.<scope_creater>`
199-
if not self._safe_decorator:
200-
for item in (i.context_expr for i in node.items):
206+
for item in node.items:
207+
# 101
208+
if not self._safe_decorator and not self._yield_is_error:
201209
if (
202-
get_trio_scope(item, "open_nursery", *cancel_scope_names)
210+
get_trio_scope(
211+
item.context_expr, "open_nursery", *cancel_scope_names
212+
)
203213
is not None
204214
):
205215< 6D40 /code>
self._yield_is_error = True
206-
break
216+
# 302
217+
if isinstance(item.optional_vars, ast.Name) and isinstance(
218+
item.context_expr, ast.Call
219+
):
220+
is_nursery = (
221+
get_trio_scope(item.context_expr, "open_nursery") is not None
222+
)
223+
poop = (item.context_expr.func, item.optional_vars.id, is_nursery)
224+
self._context_manager_stack.append(poop)
207225

208226
self.generic_visit(node)
209227

210228
# reset yield_is_error
211229
self.set_state(outer)
212230

213-
def visit_AsyncWith(self, node: ast.AsyncWith):
214-
outer = self._inside_nursery
215-
self.check_for_trio302(node.items)
216-
self.visit_With(node)
217-
self._inside_nursery = outer
231+
visit_AsyncWith = visit_With
218232

219233
# ---- 100 ----
220234
def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
@@ -231,7 +245,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
231245
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
232246
outer = self.get_state()
233247
self._yield_is_error = False
234-
self._inside_nursery = None
248+
self._context_manager_stack = []
235249

236250
# check for @<context_manager_name> and @<library>.<context_manager_name>
237251
if has_decorator(node.decorator_list, *context_manager_names):
@@ -284,16 +298,40 @@ def check_for_110(self, node: ast.While):
284298
):
285299
self.error("TRIO110", node)
286300

287-
def check_for_trio302(self, withitems: List[ast.withitem]):
288-
calls = [w.context_expr for w in withitems]
289-
for call in calls:
290-
ss = get_trio_scope(call)
291-
if not ss:
292-
continue
293-
if ss.funcname == "open_nursery":
294-
self._inside_nursery = ss.node.lineno
295-
elif self._inside_nursery is not None:
296-
self.error("TRIO302", ss.node, self._inside_nursery)
301+
def visit_Call(self, node: ast.Call):
302+
def get_id(node: ast.AST) -> Optional[ast.Name]:
303+
if isinstance(node, ast.Name):
304+
return node
305+
if isinstance(node, ast.Attribute):
306+
return get_id(node.value)
307+
if isinstance(node, ast.keyword):
308+
return get_id(node.value)
309+
return None
310+
311+
if (
312+
isinstance(node.func, ast.Attribute)
313+
and isinstance(node.func.value, ast.Name)
314+
and node.func.attr in ("start", "start_soon")
315+
):
316+
called_vars: Dict[str, ast.Name] = {}
317+
for arg in (*node.args, *node.keywords):
318+
name = get_id(arg)
319+
if name:
320+
called_vars[name.id] = name
321+
322+
nursery_call = None
323+
for expr, cm_name, is_nursery in self._context_manager_stack:
324+
if node.func.value.id == cm_name:
325+
if not is_nursery:
326+
break
327+
nursery_call = expr
328+
continue
329+
if nursery_call is None:
330+
continue
331+
if cm_name in called_vars:
332+
self.error("TRIO302", node, expr.lineno, nursery_call.lineno)
333+
334+
self.generic_visit(node)
297335

298336

299337
def critical_except(node: ast.ExceptHandler) -> Optional[Statement]:

tests/test_flake8_trio.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ def test_eval(test: str, path: str):
7878
except Exception as e:
7979
print(f"lineno: {lineno}, line: {line}", file=sys.stderr)
8080
raise e
81-
col, *args = args
81+
if args:
82+
col, *args = args
83+
else:
84+
col = 0
8285
assert isinstance(
8386
col, int
8487
), f'invalid column "{col}" @L{lineno}, in "{line}"'
@@ -115,13 +118,15 @@ def assert_expected_errors(test_file: str, include: Iterable[str], *expected: Er
115118

116119
def print_first_diff(errors: Sequence[Error], expected: Sequence[Error]):
117120
first_error_line: List[Error] = []
118-
for e in errors:
119-
if e.line == errors[0].line:
120-
first_error_line.append(e)
121121
first_expected_line: List[Error] = []
122-
for e in expected:
123-
if e.line == expected[0].line:
124-
first_expected_line.append(e)
122+
for err, exp in zip(errors, expected):
123+
if err == exp:
124+
continue
125+
if not first_error_line or err.line == first_error_line[0]:
126+
first_error_line.append(err)
127+
if not first_expected_line or exp.line == first_expected_line[0]:
128+
first_expected_line.append(exp)
129+
125130
if first_expected_line != first_error_line:
126131
print(
127132
"First lines with different errors",

0 commit comments

Comments
 (0)
0