8
8
from __future__ import annotations
9
9
10
10
from dataclasses import dataclass , field
11
- from typing import Any
11
+ from typing import TYPE_CHECKING , Any
12
12
13
13
import libcst as cst
14
14
import libcst .matchers as m
24
24
iter_guaranteed_once_cst ,
25
25
)
26
26
27
+ if TYPE_CHECKING :
28
+ from collections .abc import Sequence
29
+
30
+
27
31
ARTIFICIAL_STATEMENT = Statement ("artificial" , - 1 )
28
32
29
33
34
+ def checkpoint_statement () -> cst .SimpleStatementLine :
35
+ return cst .SimpleStatementLine (
36
+ [cst .Expr (cst .parse_expression ("trio.lowlevel.checkpoint()" ))]
37
+ )
38
+
39
+
30
40
def func_empty_body (node : cst .FunctionDef ) -> bool :
31
41
# Does the function body consist solely of `pass`, `...`, and (doc)string literals?
32
42
empty_statement = m .Pass () | m .Expr (m .Ellipsis () | m .SimpleString ())
@@ -46,8 +56,10 @@ class LoopState:
46
56
47
57
uncheckpointed_before_continue : set [Statement ] = field (default_factory = set )
48
58
uncheckpointed_before_break : set [Statement ] = field (default_factory = set )
49
- artificial_errors : set [cst .Return | cst .FunctionDef | cst .Yield ] = field (
50
- default_factory = set
59
+
60
+ artificial_errors : set [cst .Return | cst .Yield ] = field (default_factory = set )
61
+ nodes_needing_checkpoints : list [cst .Yield | cst .Return ] = field (
62
+ default_factory = list
51
63
)
52
64
53
65
def copy (self ):
@@ -58,6 +70,7 @@ def copy(self):
58
70
uncheckpointed_before_continue = self .uncheckpointed_before_continue .copy (),
59
71
uncheckpointed_before_break = self .uncheckpointed_before_break .copy (),
60
72
artificial_errors = self .artificial_errors .copy (),
73
+ nodes_needing_checkpoints = self .nodes_needing_checkpoints .copy (),
61
74
)
62
75
63
76
@@ -99,6 +112,10 @@ def __init__(self, *args: Any, **kwargs: Any):
99
112
self .uncheckpointed_statements : set [Statement ] = set ()
100
113
self .comp_unknown = False
101
114
115
+ # this one is not save-stated, but I fail to come up with any scenario
116
+ # where that matters
117
+ self .add_statement : cst .SimpleStatementLine | None = None
118
+
102
119
self .loop_state = LoopState ()
103
120
self .try_state = TryState ()
104
121
@@ -142,14 +159,6 @@ def uncheckpointed_before_break(self) -> set[Statement]:
142
159
def uncheckpointed_before_break (self , value : set [Statement ]):
143
160
self .loop_state .uncheckpointed_before_break = value
144
161
145
- @property
146
- def artificial_errors (self ) -> set [cst .Return | cst .FunctionDef | cst .Yield ]:
147
- return self .loop_state .artificial_errors
148
-
149
- @artificial_errors .setter
150
- def artificial_errors (self , value : set [cst .Return | cst .FunctionDef | cst .Yield ]):
151
- self .loop_state .artificial_errors = value # pragma: no cover
152
-
153
162
def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
154
163
# don't lint functions whose bodies solely consist of pass or ellipsis
155
164
if func_has_decorator (node , "overload" , "fixture" ) or func_empty_body (node ):
@@ -186,38 +195,93 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
186
195
def leave_FunctionDef (
187
196
self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef
188
197
) -> cst .FunctionDef :
198
+ any_error = False
189
199
if self .async_function :
190
200
# updated_node does not have a Position, so we must send original_node
191
- self .check_function_exit (original_node )
201
+ any_error = self .check_function_exit (original_node )
192
202
self .restore_state (original_node )
203
+ if (
204
+ any_error
205
+ and self .options .autofix
206
+ and isinstance (updated_node .body , cst .IndentedBlock )
207
+ ):
208
+ new_body = list (updated_node .body .body ) + [checkpoint_statement ()]
209
+ indentedblock = updated_node .body .with_changes (body = new_body )
210
+ return updated_node .with_changes (body = indentedblock )
193
211
return updated_node
194
212
195
213
# error if function exits or returns with uncheckpointed statements
196
- def check_function_exit (self , node : cst .Return | cst .FunctionDef ):
214
+ def check_function_exit (
215
+ self ,
216
+ original_node : cst .Return | cst .FunctionDef | cst .Yield ,
217
+ ) -> bool :
218
+ any_error = False
219
+
220
+ # Add this as a node potentially needing checkpoints only if it
221
+ # solely depends on whether the artificial statement is "real"
222
+ if self .uncheckpointed_statements == {ARTIFICIAL_STATEMENT }:
223
+ assert isinstance (original_node , (cst .Return , cst .Yield ))
224
+ self .loop_state .nodes_needing_checkpoints .append (original_node )
225
+
197
226
for statement in self .uncheckpointed_statements :
198
- self .error_91x (node , statement )
227
+ any_error |= self .error_91x (original_node , statement )
228
+ return any_error
199
229
200
230
def leave_Return (
201
231
self , original_node : cst .Return , updated_node : cst .Return
202
232
) -> cst .Return :
203
233
if not self .async_function :
204
234
return updated_node
205
- self .check_function_exit (original_node )
235
+ if self .check_function_exit (original_node ):
236
+ self .add_statement = checkpoint_statement ()
206
237
# avoid duplicate error messages
207
238
self .uncheckpointed_statements = set ()
208
239
240
+ # return original node to avoid problems with identity equality
241
+ assert original_node .deep_equals (updated_node )
242
+ return original_node
243
+
244
+ # this could probably be handled, but I'll settle for not inserting the checkpoint
245
+ # in the wrong place
246
+ def leave_SimpleStatementSuite (
247
+ self ,
248
+ original_node : cst .SimpleStatementSuite ,
249
+ updated_node : cst .SimpleStatementSuite ,
250
+ ) -> cst .SimpleStatementSuite :
251
+ self .add_statement = None
209
252
return updated_node
210
253
254
+ def leave_SimpleStatementLine (
255
+ self ,
256
+ original_node : cst .SimpleStatementLine ,
257
+ updated_node : cst .SimpleStatementLine ,
258
+ ) -> cst .SimpleStatementLine | cst .FlattenSentinel [cst .SimpleStatementLine ]:
259
+ if self .add_statement is None :
260
+ return updated_node
261
+ # multiple statements on a single line is not handled
262
+ # TODO: generate an error here if transforming+visiting is done in a single pass
263
+ # and emit-error-on-transform is enabled
264
+ if len (updated_node .body ) > 1 :
265
+ self .add_statement = None
266
+ return updated_node
267
+
268
+ res = cst .FlattenSentinel ([self .add_statement , updated_node ])
269
+ self .add_statement = None
270
+ return res # noqa: R504
271
+
211
272
def error_91x (
212
- self , node : cst .Return | cst .FunctionDef | cst .Yield , statement : Statement
213
- ):
273
+ self ,
274
+ node : cst .Return | cst .FunctionDef | cst .Yield ,
275
+ statement : Statement ,
276
+ ) -> bool :
214
277
# artificial statement is injected in visit_While_body to make sure errors
215
278
# are raised on multiple loops, if e.g. the end of a loop is uncheckpointed.
216
279
# Here we add it to artificial errors, so loop logic can later turn it into
217
280
# a real error if needed.
218
281
if statement == ARTIFICIAL_STATEMENT :
282
+ assert isinstance (node , (cst .Return , cst .Yield ))
219
283
self .loop_state .artificial_errors .add (node )
220
- return
284
+ return False
221
285
if isinstance (node , cst .FunctionDef ):
222
286
msg = "exit"
223
287
else :
@@ -229,6 +293,7 @@ def error_91x(
229
293
statement ,
230
294
error_code = "TRIO911" if self .has_yield else "TRIO910" ,
231
295
)
296
+ return True
232
297
233
298
def leave_Await (
234
299
self , original_node : cst .Await , updated_node : cst .Await
@@ -260,13 +325,16 @@ def leave_Yield(
260
325
if not self .async_function :
261
326
return updated_node
262
327
self .has_yield = True
263
- for statement in self .uncheckpointed_statements :
264
- self .error_91x (original_node , statement )
328
+
329
+ if self .check_function_exit (original_node ):
330
+ self .add_statement = checkpoint_statement ()
265
331
266
332
# mark as requiring checkpoint after
267
333
pos = self .get_metadata (PositionProvider , original_node ).start
268
334
self .uncheckpointed_statements = {Statement ("yield" , pos .line , pos .column )}
269
- return updated_node
335
+ # return original to avoid problems with identity equality
336
+ assert original_node .deep_equals (updated_node )
337
+ return original_node
270
338
271
339
# valid checkpoint if there's valid checkpoints (or raise) in:
272
340
# (try or else) and all excepts, or in finally
@@ -283,6 +351,7 @@ def visit_Try(self, node: cst.Try):
283
351
self .try_state .body_uncheckpointed_statements = (
284
352
self .uncheckpointed_statements .copy ()
285
353
)
354
+ # yields inside `try` can always be uncheckpointed
286
355
for inner_node in m .findall (node .body , m .Yield ()):
287
356
pos = self .get_metadata (PositionProvider , inner_node ).start
288
357
self .try_state .body_uncheckpointed_statements .add (
@@ -394,6 +463,9 @@ def visit_While(self, node: cst.While | cst.For):
394
463
self .loop_state = LoopState ()
395
464
self .infinite_loop = self .body_guaranteed_once = False
396
465
466
+ # big bug not having this, TODO: add test case
467
+ visit_For = visit_While
468
+
397
469
# Check for yields w/o checkpoint in between due to entering loop body the first time,
398
470
# after completing all of loop body, and after any continues.
399
471
# yield in else have same requirement
@@ -409,7 +481,7 @@ def visit_While_test(self, node: cst.While):
409
481
def visit_For_iter (self , node : cst .For ):
410
482
self .body_guaranteed_once = iter_guaranteed_once_cst (node .iter )
411
483
412
- def visit_While_body (self , node : cst .While | cst .For ):
484
+ def visit_While_body (self , node : cst .For | cst .While ):
413
485
if not self .async_function :
414
486
return
415
487
@@ -432,7 +504,7 @@ def visit_While_body(self, node: cst.While | cst.For):
432
504
433
505
visit_For_body = visit_While_body
434
506
435
- def leave_While_body (self , node : cst .While | cst .For ):
507
+ def leave_While_body (self , node : cst .For | cst .While ):
436
508
if not self .async_function :
437
509
return
438
510
# if there's errors due to the artificial statement
@@ -442,11 +514,16 @@ def leave_While_body(self, node: cst.While | cst.For):
442
514
self .outer [node ]["uncheckpointed_statements" ]
443
515
| self .uncheckpointed_statements
444
516
)
445
- for err_node in self .artificial_errors :
517
+ any_error = False
518
+
519
+ for err_node in self .loop_state .artificial_errors :
446
520
for stmt in (
447
521
new_uncheckpointed_statements | self .uncheckpointed_before_continue
448
522
):
523
+ any_error = True
449
524
self .error_91x (err_node , stmt )
525
+ if not any_error :
526
+ self .loop_state .nodes_needing_checkpoints = []
450
527
451
528
# replace artificial in break with prebody uncheckpointed statements
452
529
for stmts in (
@@ -479,7 +556,7 @@ def leave_While_body(self, node: cst.While | cst.For):
479
556
480
557
leave_For_body = leave_While_body
481
558
482
- def leave_While_orelse (self , node : cst .While | cst .For ):
559
+ def leave_While_orelse (self , node : cst .For | cst .While ):
483
560
if not self .async_function :
484
561
return
485
562
# if this is an infinite loop, with no break in it, don't raise
@@ -494,10 +571,30 @@ def leave_While_orelse(self, node: cst.While | cst.For):
494
571
495
572
# reset break & continue in case of nested loops
496
573
self .outer [node ]["uncheckpointed_statements" ] = self .uncheckpointed_statements
497
- self .restore_state (node )
498
574
499
575
leave_For_orelse = leave_While_orelse
500
576
577
+ def leave_While (
578
+ self , original_node : cst .For | cst .While , updated_node : cst .For | cst .While
579
+ ) -> (
580
+ cst .While
581
+ | cst .For
582
+ | cst .FlattenSentinel [cst .For | cst .While ]
583
+ | cst .RemovalSentinel
584
+ ):
585
+ if self .loop_state .nodes_needing_checkpoints :
586
+ transformer = InsertCheckpointsInLoopBody (
587
+ self .loop_state .nodes_needing_checkpoints
588
+ )
589
+ res = updated_node .visit (transformer )
590
+ else :
591
+ res = updated_node
592
+
593
+ self .restore_state (original_node )
594
+ return res # noqa: R504
595
+
596
+ leave_For = leave_While
597
+
501
598
# save state in case of continue/break at a point not guaranteed to checkpoint
502
599
def visit_Continue (self , node : cst .Continue ):
503
600
if not self .async_function :
@@ -590,3 +687,27 @@ def visit_CompFor(self, node: cst.CompFor):
590
687
# ignore their content
591
688
def visit_GeneratorExp (self , node : cst .GeneratorExp ):
592
689
return False
690
+
691
+
692
+ # necessary as we don't know whether to insert checkpoints on the first pass of a loop
693
+ # so we transform the loop body afterwards
694
+ class InsertCheckpointsInLoopBody (cst .CSTTransformer ):
695
+ def __init__ (self , nodes_needing_checkpoint : Sequence [cst .Yield | cst .Return ]):
696
+ super ().__init__ ()
697
+ self .nodes_needing_checkpoint = nodes_needing_checkpoint
698
+ self .add_statement : cst .SimpleStatementLine | None = None
699
+
700
+ leave_SimpleStatementLine = Visitor91X .leave_SimpleStatementLine # type: ignore
701
+
702
+ def leave_Yield (
703
+ self ,
704
+ original_node : cst .Yield ,
705
+ updated_node : cst .Yield ,
706
+ ) -> cst .Yield :
707
+ # we need to check *original* node here, since updated node is a copy
708
+ # which loses identity equality
709
+ if original_node in self .nodes_needing_checkpoint :
710
+ self .add_statement = checkpoint_statement ()
711
+ return updated_node
712
+
713
+ leave_Return = leave_Yield # type: ignore
0 commit comments