11
11
12
12
import ast
13
13
import tokenize
14
- from typing import Any , Dict , Iterable , List , NamedTuple , Optional , Set , Union
14
+ from typing import Any , Dict , Iterable , List , NamedTuple , Optional , Set , Tuple , Union
15
15
16
16
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
17
17
__version__ = "22.8.3"
29
29
"TRIO108" : "{0} from async iterable with no guaranteed checkpoint since {1.name} on line {1.lineno}" ,
30
30
"TRIO109" : "Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead" ,
31
31
"TRIO110" : "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`." ,
32
- "TRIO302" : "async context manager inside nursery opened on line {}. Nurseries should be outermost. " ,
32
+ "TRIO302" : "call to nursery.start/start_soon with resource from context manager opened on line {} something something nursery on line {} " ,
33
33
}
34
34
35
35
@@ -40,7 +40,7 @@ class Statement(NamedTuple):
40
40
41
41
# ignore col offset since many tests don't supply that
42
42
def __eq__ (self , other : Any ) -> bool :
43
- return isinstance (other , Statement ) and self [:2 ] == other [:2 ]
43
+ return isinstance (other , Statement ) and self [:2 ] == other [:2 ] # type: ignore
44
44
45
45
46
46
HasLineInfo = Union [ast .expr , ast .stmt , ast .arg , ast .excepthandler , Statement ]
@@ -140,10 +140,19 @@ def error(self, error: str, node: HasLineInfo, *args: object):
140
140
if not self .suppress_errors :
141
141
self ._problems .append (Error (error , node .lineno , node .col_offset , * args ))
142
142
143
- def get_state (self , * attrs : str ) -> Dict [str , Any ]:
143
+ def get_state (self , * attrs : str , copy : bool = False ) -> Dict [str , Any ]:
144
144
if not attrs :
145
145
attrs = tuple (self .__dict__ .keys ())
146
- return {attr : getattr (self , attr ) for attr in attrs if<
8000
/span> attr != "_problems" }
146
+ res : Dict [str , Any ] = {}
147
+ for attr in attrs :
148
+ if attr == "_problems" :
149
+ continue
150
+ value = getattr (self , attr )
151
+ if copy and hasattr (value , "copy" ):
152
+ value = value .copy ()
153
+ res [attr ] = value
154
+ return res
155
+ # return {attr: getattr(self, attr) for attr in attrs if attr != "_problems"}
147
156
148
157
def set_state (self , attrs : Dict [str , Any ], copy : bool = False ):
149
158
for attr , value in attrs .items ():
@@ -185,36 +194,41 @@ def __init__(self):
185
194
# variables only used for 101
186
195
self ._yield_is_error = False
187
196
self ._safe_decorator = False
188
- self ._inside_nursery : Optional [ int ] = None
197
+ self ._context_manager_stack : List [ Tuple [ ast . expr , str , bool ]] = []
189
198
190
- # ---- 100, 101 ----
199
+ # ---- 100, 101, 302 ----
191
200
def visit_With (self , node : Union [ast .With , ast .AsyncWith ]):
192
- # 100
193
201
self .check_for_trio100 (node )
194
202
195
- # 101 for rest of function
196
- outer = self .get_state ("_yield_is_error" )
203
+ outer = self .get_state ("_yield_is_error" , "_context_manager_stack" , copy = True )
197
204
198
205
# Check for a `with trio.<scope_creater>`
199
- if not self ._safe_decorator :
200
- for item in (i .context_expr for i in node .items ):
206
+ for item in node .items :
207
+ # 101
208
+ if not self ._safe_decorator and not self ._yield_is_error :
201
209
if (
202
- get_trio_scope (item , "open_nursery" , * cancel_scope_names )
210
+ get_trio_scope (
211
+ item .context_expr , "open_nursery" , * cancel_scope_names
212
+ )
203
213
is not None
204
214
):
205
215<
6D40
/code>
self ._yield_is_error = True
206
- break
216
+ # 302
217
+ if isinstance (item .optional_vars , ast .Name ) and isinstance (
218
+ item .context_expr , ast .Call
219
+ ):
220
+ is_nursery = (
221
+ get_trio_scope (item .context_expr , "open_nursery" ) is not None
222
+ )
223
+ poop = (item .context_expr .func , item .optional_vars .id , is_nursery )
224
+ self ._context_manager_stack .append (poop )
207
225
208
226
self .generic_visit (node )
209
227
210
228
# reset yield_is_error
211
229
self .set_state (outer )
212
230
213
- def visit_AsyncWith (self , node : ast .AsyncWith ):
214
- outer = self ._inside_nursery
215
- self .check_for_trio302 (node .items )
216
- self .visit_With (node )
217
- self ._inside_nursery = outer
231
+ visit_AsyncWith = visit_With
218
232
219
233
# ---- 100 ----
220
234
def check_for_trio100 (self , node : Union [ast .With , ast .AsyncWith ]):
@@ -231,7 +245,7 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]):
231
245
def visit_FunctionDef (self , node : Union [ast .FunctionDef , ast .AsyncFunctionDef ]):
232
246
outer = self .get_state ()
233
247
self ._yield_is_error = False
234
- self ._inside_nursery = None
248
+ self ._context_manager_stack = []
235
249
236
250
# check for @<context_manager_name> and @<library>.<context_manager_name>
237
251
if has_decorator (node .decorator_list , * context_manager_names ):
@@ -284,16 +298,40 @@ def check_for_110(self, node: ast.While):
284
298
):
285
299
self .error ("TRIO110" , node )
286
300
287
- def check_for_trio302 (self , withitems : List [ast .withitem ]):
288
- calls = [w .context_expr for w in withitems ]
289
- for call in calls :
290
- ss = get_trio_scope (call )
291
- if not ss :
292
- continue
293
- if ss .funcname == "open_nursery" :
294
- self ._inside_nursery = ss .node .lineno
295
- elif self ._inside_nursery is not None :
296
- self .error ("TRIO302" , ss .node , self ._inside_nursery )
301
+ def visit_Call (self , node : ast .Call ):
302
+ def get_id (node : ast .AST ) -> Optional [ast .Name ]:
303
+ if isinstance (node , ast .Name ):
304
+ return node
305
+ if isinstance (node , ast .Attribute ):
306
+ return get_id (node .value )
307
+ if isinstance (node , ast .keyword ):
308
+ return get_id (node .value )
309
+ return None
310
+
311
+ if (
312
+ isinstance (node .func , ast .Attribute )
313
+ and isinstance (node .func .value , ast .Name )
314
+ and node .func .attr in ("start" , "start_soon" )
315
+ ):
316
+ called_vars : Dict [str , ast .Name ] = {}
317
+ for arg in (* node .args , * node .keywords ):
318
+ name = get_id (arg )
319
+ if name :
320
+ called_vars [name .id ] = name
321
+
322
+ nursery_call = None
323
+ for expr , cm_name , is_nursery in self ._context_manager_stack :
324
+ if node .func .value .id == cm_name :
325
+ if not is_nursery :
326
+ break
327
+ nursery_call = expr
328
+ continue
329
+ if nursery_call is None :
330
+ continue
331
+ if cm_name in called_vars :
332
+ self .error ("TRIO302" , node , expr .lineno , nursery_call .lineno )
333
+
334
+ self .generic_visit (node )
297
335
298
336
299
337
def critical_except (node : ast .ExceptHandler ) -> Optional [Statement ]:
0 commit comments