11
11
12
12
import ast
13
13
import tokenize
14
- from typing import Any , Dict , Iterable , List , NamedTuple , Optional , Set , Union
14
+ from typing import (
15
+ Any ,
16
+ Dict ,
17
+ Iterable ,
18
+ List ,
19
+ NamedTuple ,
20
+ Optional ,
21
+ Set ,
22
+ Tuple ,
23
+ Union ,
24
+ cast ,
25
+ )
15
26
16
27
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
17
28
__version__ = "22.8.4"
44
55
"`trio.[fail/move_on]_[after/at]` instead"
45
56
),
46
57
"TRIO110" : "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`." ,
58
+ "TRIO112" : "Redundant nursery {}, consider replacing with a regular function call" ,
47
59
}
48
60
49
61
@@ -63,32 +75,21 @@ def __eq__(self, other: Any) -> bool:
63
75
)
64
76
65
77
66
- HasLineInfo = Union [ast .expr , ast .stmt , ast .arg , ast .excepthandler , Statement ]
67
-
68
-
69
- class TrioScope :
70
- def __init__ (self , node : ast .Call , funcname : str ):
71
- self .node = node
72
- self .funcname = funcname
73
- self .variable_name : Optional [str ] = None
74
- self .shielded : bool = False
75
- self .has_timeout : bool = True
78
+ HasLineCol = Union [ast .expr , ast .stmt , ast .arg , ast .excepthandler , Statement ]
76
79
77
- # scope.shield is assigned to in visit_Assign
78
80
79
- if self .funcname == "CancelScope" :
80
- self .has_timeout = False
81
- for kw in node .keywords :
82
- # Only accepts constant values
83
- if kw .arg == "shield" and isinstance (kw .value , ast .Constant ):
84
- self .shielded = kw .value .value
85
- # sets to True even if timeout is explicitly set to inf
86
- if kw .arg == "deadline" :
87
- self .has_timeout = True
88
-
89
- def __str__ (self ):
90
- # Not supporting other ways of importing trio, per TRIO106
91
- return f"trio.{ self .funcname } "
81
+ def get_matching_call (
82
+ node : ast .AST , * names : str , base : str = "trio"
83
+ ) -> Optional [Tuple [ast .Call , str ]]:
84
+ if (
85
+ isinstance (node , ast .Call )
86
+ and isinstance (node .func , ast .Attribute )
87
+ and isinstance (node .func .value , ast .Name )
88
+ and node .func .value .id == base
89
+ and node .func .attr in names
90
+ ):
91
+ return node , node .func .attr
92
+ return None
92
93
93
94
94
95
class Error :
@@ -157,7 +158,7 @@ def visit_nodes(
157
158
for node in arg :
158
159
visit (node )
159
160
160
- def error (self , error : str , node : HasLineInfo , * args : object ):
161
+ def error (self , error : str , node : HasLineCol , * args : object ):
161
162
if not self .suppress_errors :
162
163
self ._problems .append (Error (error , node .lineno , node .col_offset , * args ))
163
164
@@ -177,18 +178,6 @@ def walk(self, *body: ast.AST) -> Iterable[ast.AST]:
177
178
yield from ast .walk (b )
178
179
179
180
180
- def get_trio_scope (node : ast .AST , * names : str ) -> Optional [TrioScope ]:
181
- if (
182
- isinstance (node , ast .Call )
183
- and isinstance (node .func , ast .Attribute )
184
- and isinstance (node .func .value , ast .Name )
185
- and node .func .value .id == "trio"
186
- and node .func .attr in names
187
- ):
188
- return TrioScope (node , node .func .attr )
189
- return None
190
-
191
-
192
181
def has_decorator (decorator_list : List [ast .expr ], * names : str ):
193
182
for dec in decorator_list :
194
183
if (isinstance (dec , ast .Name ) and dec .id in names ) or (
@@ -211,6 +200,7 @@ def __init__(self):
211
200
def visit_With (self , node : Union [ast .With , ast .AsyncWith ]):
212
201
# 100
213
202
self .check_for_trio100 (node )
203
+ self .check_for_trio112 (node )
214
204
215
205
# 101 for rest of function
216
206
outer = self .get_state ("_yield_is_error" )
@@ -219,7 +209,7 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
219
209
if not self ._safe_decorator :
220
210
for item in (i .context_expr for i in node .items ):
221
211
if (
222
- get_trio_scope (item , "open_nursery" , * cancel_scope_names )
212
+ get_matching_call (item , "open_nursery" , * cancel_scope_names )
223
213
is not None
224
214
):
225
215
self ._yield_is_error = True
@@ -234,14 +224,14 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
234
224
235
225
# ---- 100 ----
236
226
def check_for_trio100 (self , node : Union [ast .With , ast .AsyncWith ]):
237
- # Context manager with no `await` call within
227
+ # Context manager with no `await trio.X ` call within
238
228
for item in (i .context_expr for i in node .items ):
239
- call = get_trio_scope (item , * cancel_scope_names )
229
+ call = get_matching_call (item , * cancel_scope_names )
240
230
if call and not any (
241
231
isinstance (x , checkpoint_node_types ) and x != node
242
232
for x in ast .walk (node )
243
233
):
244
- self .error ("TRIO100" , item , str ( call ) )
234
+ self .error ("TRIO100" , item , f"trio. { call [ 1 ] } " )
245
235
246
236
# ---- 101 ----
247
237
def visit_FunctionDef (self , node : Union [ast .FunctionDef , ast .AsyncFunctionDef ]):
@@ -258,7 +248,7 @@ def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
258
248
259
249
# ---- 101, 109 ----
260
250
def visit_AsyncFunctionDef (self , node : ast .AsyncFunctionDef ):
261
- self .check_109 (node )
251
+ self .check_for_trio109 (node )
262
252
self .visit_FunctionDef (node )
263
253
264
254
# ---- 101 ----
@@ -269,7 +259,7 @@ def visit_Yield(self, node: ast.Yield):
269
259
self .generic_visit (node )
270
260
271
261
# ---- 109 ----
272
F987
- def check_109 (self , node : ast .AsyncFunctionDef ):
262
+ def check_for_trio109 (self , node : ast .AsyncFunctionDef ):
273
263
if node .decorator_list :
274
264
return
275
265
args = node .args
@@ -290,18 +280,48 @@ def visit_Import(self, node: ast.Import):
290
280
291
281
# ---- 110 ----
292
282
def visit_While (self , node : ast .While ):
293
- self .check_for_110 (node )
283
+ self .check_for_trio110 (node )
294
284
self .generic_visit (node )
295
285
296
- def check_for_110 (self , node : ast .While ):
286
+ def check_for_trio110 (self , node : ast .While ):
297
287
if (
298
288
len (node .body ) == 1
299
289
and isinstance (node .body [0 ], ast .Expr )
300
290
and isinstance (node .body [0 ].value , ast .Await )
301
- and get_trio_scope (node .body [0 ].value .value , "sleep" , "sleep_until" )
291
+ and get_matching_call (node .body [0 ].value .value , "sleep" , "sleep_until" )
302
292
):
303
293
self .error ("TRIO110" , node )
304
294
295
+ # if with has a withitem `trio.open_nursery() as <X>`,
296
+ # and the body is only a single expression <X>.start[_soon](),
297
+ # and does not pass <X> as a parameter to the expression
298
+ def check_for_trio112 (self , node : Union [ast .With , ast .AsyncWith ]):
299
+ # body is single expression
300
+ if len (node .body ) != 1 or not isinstance (node .body [0 ], ast .Expr ):
301
+ return
302
+ for item in node .items :
303
+ # get variable name <X>
304
+ if not isinstance (item .optional_vars , ast .Name ):
305
+ continue
306
+ var_name = item .optional_vars .id
307
+
308
+ # check for trio.open_nursery
309
+ nursery = get_matching_call (item .context_expr , "open_nursery" )
310
+
311
+ # isinstance(..., ast.Call) is done in get_matching_call
312
+ body_call = cast (ast .Call , node .body [0 ].value )
313
+
314
+ if (
315
+ nursery is not None
316
+ and get_matching_call (body_call , "start" , "start_soon" , base = var_name )
317
+ # check for presence of <X> as parameter
318
+ and not any (
319
+ (isinstance (n , ast .Name ) and n .id == var_name )
320
+ for n in self .walk (* body_call .args , * body_call .keywords )
321
+ )
322
+ ):
323
+ self .error ("TRIO112" , item .context_expr , var_name )
324
+
305
325
306
326
def critical_except (node : ast .ExceptHandler ) -> Optional [Statement ]:
307
327
def has_exception (node : Optional [ast .expr ]) -> str :
@@ -333,10 +353,30 @@ def has_exception(node: Optional[ast.expr]) -> str:
333
353
334
354
335
355
class Visitor102 (Flake8TrioVisitor ):
356
+ class TrioScope :
357
+ def __init__ (self , node : ast .Call , funcname : str ):
358
+ self .node = node
359
+ self .funcname = funcname
360
+ self .variable_name : Optional [str ] = None
361
+ self .shielded : bool = False
362
+ self .has_timeout : bool = True
363
+
364
+ # scope.shielded is assigned to in visit_Assign
365
+
366
+ if self .funcname == "CancelScope" :
367
+ self .has_timeout = False
368
+ for kw in node .keywords :
369
+ # Only accepts constant values
370
+ if kw .arg == "shield" and isinstance (kw .value , ast .Constant ):
371
+ self .shielded = kw .value .value
372
+ # sets to True even if timeout is explicitly set to inf
373
+ if kw .arg == "deadline" :
374
+ self .has_timeout = True
375
+
336
376
def __init__ (self ):
337
377
super ().__init__ ()
338
378
self ._critical_scope : Optional [Statement ] = None
339
- self ._trio_context_managers : List [TrioScope ] = []
379
+ self ._trio_context_managers : List [Visitor102 . TrioScope ] = []
340
380
self ._safe_decorator = False
341
381
342
382
# if we're inside a finally, and not inside a context_manager, and we're not
@@ -364,17 +404,19 @@ def visit_With(self, node: Union[ast.With, ast.AsyncWith]):
364
404
365
405
# Check for a `with trio.<scope_creater>`
366
406
for item in node .items :
367
- trio_scope = get_trio_scope (
407
+ call = get_matching_call (
368
408
item .context_expr , "open_nursery" , * cancel_scope_names
369
409
)
370
- if trio_scope is None :
410
+ if call is None :
371
411
continue
372
412
373
- self ._trio_context_managers .append (trio_scope )
374
- has_context_manager = True
413
+ trio_scope = self .TrioScope (* call )
375
414
# check if it's saved in a variable
376
415
if isinstance (item .optional_vars , ast .Name ):
377
416
trio_scope .variable_name = item .optional_vars .id
417
+
418
+ self ._trio_context_managers .append (trio_scope )
419
+ has_context_manager = True
378
420
break
379
421
380
422
self .generic_visit (node )
0 commit comments