7
7
8
8
from __future__ import annotations
9
9
10
+ from abc import ABC , abstractmethod
10
11
from dataclasses import dataclass , field
11
12
from typing import TYPE_CHECKING , Any
12
13
@@ -90,9 +91,114 @@ def checkpoint_statement(library: str) -> cst.SimpleStatementLine:
90
91
)
91
92
92
93
94
+ class CommonVisitors (cst .CSTTransformer , ABC ):
95
+ def __init__ (self ):
96
+ super ().__init__ ()
97
+ self .noautofix : bool = False
98
+ # this one is not save-stated, but I fail to come up with any scenario
99
+ # where that matters
100
+ self .add_statement : cst .SimpleStatementLine | None = None
101
+
102
+ # these are file-wide, so intentionally not save-stated upon entry/exit
103
+ # of functions/loops/etc
104
+ self .explicitly_imported_library : dict [str , bool ] = {
105
+ "trio" : False ,
106
+ "anyio" : False ,
107
+ }
108
+ self .add_import : set [str ] = set ()
109
+
110
+ self .booldepth = 0
111
+
112
+ @property
113
+ @abstractmethod
114
+ def library (self ) -> tuple [str , ...]:
115
+ ...
116
+
117
+ # TODO: generate an error in these two if transforming+visiting is done in a single
118
+ # pass and emit-error-on-transform can be enabled/disabled. The error can't be
119
+ # generated in the yield/return since it doesn't know if it will be autofixed.
120
+
121
+ # instead of trying to exclude yields found in all the weird places from
122
+ # setting self.add_statement, we instead clear it upon each new line
123
+ def visit_SimpleStatementLine (self , node : cst .SimpleStatementLine ):
124
+ self .add_statement = None
125
+
126
+ # insert checkpoint before yield with a flattensentinel, if indicated
127
+ def leave_SimpleStatementLine (
128
+ self ,
129
+ original_node : cst .SimpleStatementLine ,
130
+ updated_node : cst .SimpleStatementLine ,
131
+ ) -> cst .SimpleStatementLine | cst .FlattenSentinel [cst .SimpleStatementLine ]:
132
+ if self .add_statement is None :
133
+ return updated_node
134
+ curr_add_statement = self .add_statement
135
+ self .add_statement = None
136
+
137
+ # multiple statements on a single line is not handled
138
+ # yields in boolops are also not caught by any of the other scenarios
139
+ if len (updated_node .body ) > 1 :
140
+ return updated_node
141
+
142
+ self .ensure_imported_library ()
143
+ return cst .FlattenSentinel ([curr_add_statement , updated_node ])
144
+
145
+ def visit_BooleanOperation (self , node : cst .BooleanOperation ):
146
+ self .booldepth += 1
147
+ self .noautofix = True
148
+
149
+ def leave_BooleanOperation (
150
+ self , original_node : cst .BooleanOperation , updated_node : cst .BooleanOperation
151
+ ):
152
+ assert self .booldepth
153
+ self .booldepth -= 1
154
+ if not self .booldepth :
155
+ self .noautofix = False
156
+ return updated_node
157
+
158
+ def ensure_imported_library (self ) -> None :
159
+ """Mark library for import.
160
+
161
+ Check that the library we'd use to insert checkpoints
162
+ is imported - if not, mark it to be inserted later.
163
+ """
164
+ assert self .library
165
+ if not self .explicitly_imported_library [self .library [0 ]]:
166
+ self .add_import .add (self .library [0 ])
167
+
168
+
169
+ # necessary as we don't know whether to insert checkpoints on the first pass of a loop
170
+ # so we transform the loop body afterwards
171
+ class InsertCheckpointsInLoopBody (CommonVisitors ):
172
+ def __init__ (
173
+ self ,
174
+ nodes_needing_checkpoint : Sequence [cst .Yield | cst .Return ],
175
+ library : tuple [str , ...],
176
+ ):
177
+ super ().__init__ ()
178
+ self .nodes_needing_checkpoint = nodes_needing_checkpoint
179
+ self .__library = library
180
+
181
+ @property
182
+ def library (self ) -> tuple [str , ...]:
183
+ return self .__library if self .__library else ("trio" ,)
184
+
185
+ def leave_Yield (
186
+ self ,
187
+ original_node : cst .Yield ,
188
+ updated_node : cst .Yield ,
189
+ ) -> cst .Yield :
190
+ # we need to check *original* node here, since updated node is a copy
191
+ # which loses identity equality
192
+ if original_node in self .nodes_needing_checkpoint and not self .noautofix :
193
+ self .add_statement = checkpoint_statement (self .library [0 ])
194
+ return updated_node
195
+
196
+ leave_Return = leave_Yield # type: ignore
197
+
198
+
93
199
@error_class_cst
94
200
@disabled_by_default
95
- class Visitor91X (Flake8TrioVisitor_cst ):
201
+ class Visitor91X (Flake8TrioVisitor_cst , CommonVisitors ):
96
202
error_codes = {
97
203
"TRIO910" : (
98
204
"{0} from async function with no guaranteed checkpoint or exception "
@@ -112,18 +218,6 @@ def __init__(self, *args: Any, **kwargs: Any):
112
218
self .uncheckpointed_statements : set [Statement ] = set ()
113
219
self .comp_unknown = False
114
220
115
- # these are file-wide, so intentionally not save-stated upon entry/exit
116
- # of functions/loops/etc
117
- self .explicitly_imported_library : dict [str , bool ] = {
118
- "trio" : False ,
119
- "anyio" : False ,
120
- }
121
- self .add_import : set [str ] = set ()
122
-
123
- # this one is not save-stated, but I fail to come up with any scenario
124
- # where that matters
125
- self .add_statement : cst .SimpleStatementLine | None = None
126
-
127
221
self .loop_state = LoopState ()
128
222
self .try_state = TryState ()
129
223
@@ -168,7 +262,6 @@ def uncheckpointed_before_break(self, value: set[Statement]):
168
262
self .loop_state .uncheckpointed_before_break = value
169
263
170
264
def checkpoint_statement (self ) -> cst .SimpleStatementLine :
171
- self .ensure_imported_library ()
172
265
return checkpoint_statement (self .library [0 ])
173
266
174
267
def visit_FunctionDef (self , node : cst .FunctionDef ) -> None :
@@ -248,7 +341,7 @@ def check_function_exit(
248
341
# Add this as a node potentially needing checkpoints only if it
249
342
# missing checkpoints solely depends on whether the artificial statement is
250
343
# "real"
251
- if len (self .uncheckpointed_statements ) == 1 :
344
+ if len (self .uncheckpointed_statements ) == 1 and not self . noautofix :
252
345
self .loop_state .nodes_needing_checkpoints .append (original_node )
253
346
F987
return False
254
347
@@ -274,38 +367,6 @@ def leave_Return(
274
367
assert original_node .deep_equals (updated_node )
275
368
return original_node
276
369
277
- # TODO: generate an error in these two if transforming+visiting is done in a single
278
- # pass and emit-error-on-transform can be enabled/disabled. The error can't be
279
- # generated in the yield/return since it doesn't know if it will be autofixed.
280
-
281
- # SimpleStatementSuite and multi-statement lines can probably be autofixed, but
282
- # for now just don't insert checkpoints in the wrong place.
283
- def leave_SimpleStatementSuite (
284
- self ,
285
- original_node : cst .SimpleStatementSuite ,
286
- updated_node : cst .SimpleStatementSuite ,
287
- ) -> cst .SimpleStatementSuite :
288
- self .add_statement = None
289
- return updated_node
290
-
291
- # insert checkpoint before yield with a flattensentinel, if indicated
292
- def leave_SimpleStatementLine (
293
- self ,
294
- original_node : cst .SimpleStatementLine ,
295
- updated_node : cst .SimpleStatementLine ,
296
- ) -> cst .SimpleStatementLine | cst .FlattenSentinel [cst .SimpleStatementLine ]:
297
- if self .add_statement is None :
298
- return updated_node
299
-
300
- # multiple statements on a single line is not handled
301
- if len (updated_node .body ) > 1 :
302
- self .add_statement = None
303
- return updated_node
304
-
305
- res = cst .FlattenSentinel ([self .add_statement , updated_node ])
306
- self .add_statement = None
307
- return res # noqa: R504
308
-
309
370
def error_91x (
310
371
self ,
311
372
node : cst .Return | cst .FunctionDef | cst .Yield ,
@@ -356,7 +417,7 @@ def leave_Yield(
356
417
return updated_node
357
418
self .has_yield = True
358
419
359
- if self .check_function_exit (original_node ):
420
+ if self .check_function_exit (original_node ) and not self . noautofix :
360
421
self .add_statement = self .checkpoint_statement ()
361
422
362
423
# mark as requiring checkpoint after
@@ -616,9 +677,8 @@ def leave_While(
616
677
| cst .RemovalSentinel
617
678
):
618
679
if self .loop_state .nodes_needing_checkpoints :
619
- self .ensure_imported_library ()
620
680
transformer = InsertCheckpointsInLoopBody (
621
- self .loop_state .nodes_needing_checkpoints , self .library [ 0 ]
681
+ self .loop_state .nodes_needing_checkpoints , self .library
622
682
)
623
683
# type of updated_node expanded to the return type
624
684
updated_node = updated_node .visit (transformer ) # type: ignore
@@ -651,9 +711,10 @@ def visit_BooleanOperation_right(self, node: cst.BooleanOperation):
651
711
def leave_BooleanOperation_right (self , node : cst .BooleanOperation ):
652
712
if not self .async_function :
653
713
return
654
- self .uncheckpointed_statements .update (
655
- self .outer [ node ][ " uncheckpointed_statements" ]
714
+ self .outer [ node ][ " uncheckpointed_statements" ] .update (
715
+ self .uncheckpointed_statements
656
716
)
717
+ self .restore_state (node )
657
718
658
719
# comprehensions are simpler than loops, since they cannot contain yields
659
720
# or many other complicated statements, but their subfields are not in the order
@@ -735,15 +796,6 @@ def visit_Import(self, node: cst.Import):
735
796
assert isinstance (alias .name .value , str )
736
797
self .explicitly_imported_library [alias .name .value ] = True
737
798
738
- def ensure_imported_library (self ) -> None :
739
- """Mark library for import.
740
-
741
- Check that the library we'd use to insert checkpoints
742
- is imported - if not, mark it to be inserted later.
743
- """
744
- if not self .explicitly_imported_library [self .library [0 ]]:
745
- self .add_import .add (self .library [0 ])
746
-
747
799
def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ):
748
800
"""Add needed library import, if any, to the module."""
749
801
if not self .add_import :
@@ -765,33 +817,3 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module):
765
817
assert len (self .add_import ) == 1
766
818
new_body .insert (index , cst .parse_statement (f"import { self .library [0 ]} " ))
767
819
return updated_node .with_changes (body = new_body )
768
-
769
-
770
- # necessary as we don't know whether to insert checkpoints on the first pass of a loop
771
- # so we transform the loop body afterwards
772
- class InsertCheckpointsInLoopBody (cst .CSTTransformer ):
773
- def __init__ (
774
- self , nodes_needing_checkpoint : Sequence [cst .Yield | cst .Return ], library : str
775
- ):
776
- super ().__init__ ()
777
- self .nodes_needing_checkpoint = nodes_needing_checkpoint
778
- self .add_statement : cst .SimpleStatementLine | None = None
779
- self .library = library
780
-
781
- # insert checkpoint before yield with a flattensentinel, if indicated
782
- # type checkers don't like going across classes, esp as the method accesses
783
- # and modifies self.add_statement, but #YOLO
784
- leave_SimpleStatementLine = Visitor91X .leave_SimpleStatementLine # type: ignore
785
-
786
- def leave_Yield (
787
- self ,
788
- original_node : cst .Yield ,
789
- updated_node : cst .Yield ,
790
- ) -> cst .Yield :
791
- # we need to check *original* node here, since updated node is a copy
792
- # which loses identity equality
793
- if original_node in self .nodes_needing_checkpoint :
794
- self .add_statement = checkpoint_statement (self .library )
795
- return updated_node
796
-
797
- leave_Return = leave_Yield # type: ignore
0 commit comments