@@ -3227,115 +3227,65 @@ def test_folding_type_param_in_type_alias(self):
3227
3227
self .assert_ast (result_code , non_optimized_target , optimized_target )
3228
3228
3229
3229
def test_folding_match_case_allowed_expressions (self ):
3230
- source = textwrap .dedent ("""
3231
- match 0:
3232
- case -0: pass
3233
- case -0.1: pass
3234
- case -0j: pass
3235
- case -0.1j: pass
3236
- case 1 + 2j: pass
3237
- case 1 - 2j: pass
3238
- case 1.1 + 2.1j: pass
3239
- case 1.1 - 2.1j: pass
3240
- case -0 + 1j: pass
3241
- case -0 - 1j: pass
3242
- case -0.1 + 1.1j: pass
3243
- case -0.1 - 1.1j: pass
3244
- case {-0: 0}: pass
3245
- case {-0.1: 0}: pass
3246
- case {-0j: 0}: pass
3247
- case {-0.1j: 0}: pass
3248
- case {1 + 2j: 0}: pass
3249
- case {1 - 2j: 0}: pass
3250
- case {1.1 + 2.1j: 0}: pass
3251
- case {1.1 - 2.1j: 0}: pass
3252
- case {-0 + 1j: 0}: pass
3253
- case {-0 - 1j: 0}: pass
3254
- case {-0.1 + 1.1j: 0}: pass
3255
- case {-0.1 - 1.1j: 0}: pass
3256
- case {-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}: pass
3257
- case [-0, -0.1, -0j, -0.1j]: pass
3258
- case (-0, -0.1, -0j, -0.1j): pass
3259
- case [[-0, -0.1], [-0j, -0.1j]]: pass
3260
- case ((-0, -0.1), (-0j, -0.1j)): pass
3261
- """ )
3262
- expected_constants = (
3263
- 0 ,
3264
- - 0.1 ,
3265
- complex (0 , - 0 ),
3266
- complex (0 , - 0.1 ),
3267
- complex (1 , 2 ),
3268
- complex (1 , - 2 ),
3269
- complex (1.1 , 2.1 ),
3270
- complex (1.1 , - 2.1 ),
3271
- complex (- 0 , 1 ),
3272
- complex (- 0 , - 1 ),
3273
- complex (- 0.1 , 1.1 ),
3274
- complex (- 0.1 , - 1.1 ),
3275
- (0 , ),
3276
- (- 0.1 , ),
3277
- (complex (0 , - 0 ), ),
3278
- (complex (0 , - 0.1 ), ),
3279
- (complex (1 , 2 ), ),
3280
- (complex (1 , - 2 ), ),
3281
- (complex (1.1 , 2.1 ), ),
3282
- (complex (1.1 , - 2.1 ), ),
3283
- (complex (- 0 , 1 ), ),
3284
- (complex (- 0 , - 1 ), ),
3285
- (complex (- 0.1 , 1.1 ), ),
3286
- (complex (- 0.1 , - 1.1 ), ),
3287
- (0 , complex (0 , 1 ), complex (0.1 , 1 )),
3288
- (
3289
- 0 ,
3290
- - 0.1 ,
3291
- complex (0 , - 0 ),
3292
- complex (0 , - 0.1 ),
3293
- ),
3294
- (
3295
- 0 ,
3296
- - 0.1 ,
3297
- complex (0 , - 0 ),
3298
- complex (0 , - 0.1 ),
3299
- ),
3300
- (
3301
- 0 ,
3302
- - 0.1 ,
3303
- complex (0 , - 0 ),
3304
- complex (0 , - 0.1 ),
3305
- ),
3306
- (
3307
- 0 ,
3308
- - 0.1 ,
3309
- complex (0 , - 0 ),
3310
- complex (0 , - 0.1 ),
3311
- )
3312
- )
3313
- consts = iter (expected_constants )
3314
- tree = ast .parse (source , optimize = 1 )
3315
- match_stmt = tree .body [0 ]
3316
- for case in match_stmt .cases :
3317
- pattern = case .pattern
3318
- if isinstance (pattern , ast .MatchValue ):
3319
- self .assertIsInstance (pattern .value , ast .Constant )
3320
- self .assertEqual (pattern .value .value , next (consts ))
3321
- elif isinstance (pattern , ast .MatchMapping ):
3322
- keys = iter (next (consts ))
3323
- for key in pattern .keys :
3324
- self .assertIsInstance (key , ast .Constant )
3325
- self .assertEqual (key .value , next (keys ))
3326
- elif isinstance (pattern , ast .MatchSequence ):
3327
- values = iter (next (consts ))
3328
- for pat in pattern .patterns :
3329
- if isinstance (pat , ast .MatchValue ):
3330
- self .assertEqual (pat .value .value , next (values ))
3331
- elif isinstance (pat , ast .MatchSequence ):
3332
- for p in pat .patterns :
3333
- self .assertIsInstance (p , ast .MatchValue )
3334
- self .assertEqual (p .value .value , next (values ))
3335
- else :
3336
- self .fail (f"Expected ast.MatchValue or ast.MatchSequence, found: { type (pat )} " )
3230
+ def get_match_case_values (node ):
3231
+ result = []
3232
+ if isinstance (node , ast .Constant ):
3233
+ result .append (node .value )
3234
+ elif isinstance (node , ast .MatchValue ):
3235
+ result .extend (get_match_case_values (node .value ))
3236
+ elif isinstance (node , ast .MatchMapping ):
3237
+ for key in node .keys :
3238
+ result .extend (get_match_case_values (key ))
3239
+ elif isinstance (node , ast .MatchSequence ):
3240
+ for pat in node .patterns :
3241
+ result .extend (get_match_case_values (pat ))
3337
3242
else :
3338
- self .fail (f"Expected ast.MatchValue or ast.MatchMapping, found: { type (pattern )} " )
3243
+ self .fail (f"Unexpected node { node } " )
3244
+ return result
3245
+
3246
+ tests = [
3247
+ ("-0" , [0 ]),
3248
+ ("-0.1" , [- 0.1 ]),
3249
+ ("-0j" , [complex (0 , 0 )]),
3250
+ ("-0.1j" , [complex (0 , - 0.1 )]),
3251
+ ("1 + 2j" , [complex (1 , 2 )]),
3252
+ ("1 - 2j" , [complex (1 , - 2 )]),
3253
+ ("1.1 + 2.1j" , [complex (1.1 , 2.1 )]),
3254
+ ("1.1 - 2.1j" , [complex (1.1 , - 2.1 )]),
3255
+ ("-0 + 1j" , [complex (0 , 1 )]),
3256
+ ("-0 - 1j" , [complex (0 , - 1 )]),
3257
+ ("-0.1 + 1.1j" , [complex (- 0.1 , 1.1 )]),
3258
+ ("-0.1 - 1.1j" , [complex (- 0.1 , - 1.1 )]),
3259
+ ("{-0: 0}" , [0 ]),
3260
+ ("{-0.1: 0}" , [- 0.1 ]),
3261
+ ("{-0j: 0}" , [complex (0 , 0 )]),
3262
+ ("{-0.1j: 0}" , [complex (0 , - 0.1 )]),
3263
+ ("{1 + 2j: 0}" , [complex (1 , 2 )]),
3264
+ ("{1 - 2j: 0}" , [complex (1 , - 2 )]),
3265
+ ("{1.1 + 2.1j: 0}" , [complex (1.1 , 2.1 )]),
3266
+ ("{1.1 - 2.1j: 0}" , [complex (1.1 , - 2.1 )]),
3267
+ ("{-0 + 1j: 0}" , [complex (0 , 1 )]),
3268
+ ("{-0 - 1j: 0}" , [complex (0 , - 1 )]),
3269
+ ("{-0.1 + 1.1j: 0}" , [complex (- 0.1 , 1.1 )]),
3270
+ ("{-0.1 - 1.1j: 0}" , [complex (- 0.1 , - 1.1 )]),
3271
+ ("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}" , [0 , complex (0 , 1 ), complex (0.1 , 1 )]),
3272
+ ("[-0, -0.1, -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3273
+ ("[[[[-0, -0.1, -0j, -0.1j]]]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3274
+ ("[[-0, -0.1], -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3275
+ ("[[-0, -0.1], [-0j, -0.1j]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3276
+ ("(-0, -0.1, -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3277
+ ("((((-0, -0.1, -0j, -0.1j))))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3278
+ ("((-0, -0.1), -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3279
+ ("((-0, -0.1), (-0j, -0.1j))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3280
+ ]
3281
+ for match_expr , constants in tests :
3282
+ with self .subTest (match_expr ):
3283
+ src = f"match 0:\n \t case { match_expr } : pass"
3284
+ tree = ast .parse (src , optimize = 1 )
3285
+ match_stmt = tree .body [0 ]
3286
+ case = match_stmt .cases [0 ]
3287
+ values = get_match_case_values (case .pattern )
3288
+ self .assertListEqual (constants , values )
3339
3289
3340
3290
3341
3291
if __name__ == '__main__' :
0 commit comments