@@ -3081,21 +3081,6 @@ def test_cli_file_input(self):
3081
3081
3082
3082
3083
3083
class ASTOptimiziationTests (unittest .TestCase ):
3084
- binop = {
3085
- "+" : ast .Add (),
3086
- "-" : ast .Sub (),
3087
- "*" : ast .Mult (),
3088
- "/" : ast .Div (),
3089
- "%" : ast .Mod (),
3090
- "<<" : ast .LShift (),
3091
- ">>" : ast .RShift (),
3092
- "|" : ast .BitOr (),
3093
- "^" : ast .BitXor (),
3094
- "&" : ast .BitAnd (),
3095
- "//" : ast .FloorDiv (),
3096
- "**" : ast .Pow (),
3097
- }
3098
-
3099
3084
unaryop = {
3100
3085
"~" : ast .Invert (),
3101
3086
"+" : ast .UAdd (),
@@ -3298,6 +3283,74 @@ def test_folding_type_param_in_type_alias(self):
3298
3283
)
3299
3284
self .assert_ast (result_code , non_optimized_target , optimized_target )
3300
3285
3286
+ def test_folding_match_case_allowed_expressions (self ):
3287
+ source = textwrap .dedent ("""
3288
+ match 0:
3289
+ case -0: pass
3290
+ case -0.1: pass
3291
+ case -0j: pass
3292
+ case 1 + 2j: pass
3293
+ case 1 - 2j: pass
3294
+ case 1.1 + 2.1j: pass
3295
+ case 1.1 - 2.1j: pass
3296
+ case -0 + 1j: pass
3297
+ case -0 - 1j: pass
3298
+ case -0.1 + 1.1j: pass
3299
+ case -0.1 - 1.1j: pass
3300
+ case {-0: 0}: pass
3301
+ case {-0.1: 0}: pass
3302
+ case {-0j: 0}: pass
3303
+ case {1 + 2j: 0}: pass
3304
+ case {1 - 2j: 0}: pass
3305
+ case {1.1 + 2.1j: 0}: pass
3306
+ case {1.1 - 2.1j: 0}: pass
3307
+ case {-0 + 1j: 0}: pass
3308
+ case {-0 - 1j: 0}: pass
3309
+ case {-0.1 + 1.1j: 0}: pass
3310
+ case {-0.1 - 1.1j: 0}: pass
3311
+ case {-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}: pass
3312
+ """ )
3313
+ expected_constants = (
3314
+ 0 ,
3315
+ - 0.1 ,
3316
+ complex (0 , - 0 ),
3317
+ complex (1 , 2 ),
3318
+ complex (1 , - 2 ),
3319
+ complex (1.1 , 2.1 ),
3320
+ complex (1.1 , - 2.1 ),
3321
+ complex (- 0 , 1 ),
3322
+ complex (- 0 , - 1 ),
3323
+ complex (- 0.1 , 1.1 ),
3324
+ complex (- 0.1 , - 1.1 ),
3325
+ (0 , ),
3326
+ (- 0.1 , ),
3327
+ (complex (0 , - 0 ), ),
3328
+ (complex (1 , 2 ), ),
3329
+ (complex (1 , - 2 ), ),
3330
+ (complex (1.1 , 2.1 ), ),
3331
+ (complex (1.1 , - 2.1 ), ),
3332
+ (complex (- 0 , 1 ), ),
3333
+ (complex (- 0 , - 1 ), ),
3334
+ (complex (- 0.1 , 1.1 ), ),
3335
+ (complex (- 0.1 , - 1.1 ), ),
3336
+ (0 , complex (0 , 1 ), complex (0.1 , 1 ))
3337
+ )
3338
+ consts = iter (expected_constants )
3339
+ tree = ast .parse (source , optimize = 1 )
3340
+ match_stmt = tree .body [0 ]
3341
+ for case in match_stmt .cases :
3342
+ pattern = case .pattern
3343
+ if isinstance (pattern , ast .MatchValue ):
3344
+ self .assertIsInstance (pattern .value , ast .Constant )
3345
+ self .assertEqual (pattern .value .value , next (consts ))
3346
+ elif isinstance (pattern , ast .MatchMapping ):
3347
+ keys = iter (next (consts ))
3348
+ for key in pattern .keys :
3349
+ self .assertIsInstance (key , ast .Constant )
3350
+ self .assertEqual (key .value , next (keys ))
3351
+ else :
3352
+ self .fail (f"Expected ast.MatchValue or ast.MatchMapping, found: { type (pattern )} " )
3353
+
3301
3354
3302
<
4182
code>3355 if __name__ == '__main__' :
3303
3356
if len (sys .argv ) > 1 and sys .argv [1 ] == '--snapshot-update' :
0 commit comments