8000 add match case folding tests · python/cpython@0fa7c4c · GitHub
[go: up one dir, main page]

Skip to content

Commit 0fa7c4c

Browse files
committed
add match case folding tests
1 parent a6babab commit 0fa7c4c

File tree

1 file changed

+68
-15
lines changed

1 file changed

+68
-15
lines changed

Lib/test/test_ast/test_ast.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3081,21 +3081,6 @@ def test_cli_file_input(self):
30813081

30823082

30833083
class ASTOptimiziationTests(unittest.TestCase):
3084-
binop = {
3085-
"+": ast.Add(),
3086-
"-": ast.Sub(),
3087-
"*": ast.Mult(),
3088-
"/": ast.Div(),
3089-
"%": ast.Mod(),
3090-
"<<": ast.LShift(),
3091-
">>": ast.RShift(),
3092-
"|": ast.BitOr(),
3093-
"^": ast.BitXor(),
3094-
"&": ast.BitAnd(),
3095-
"//": ast.FloorDiv(),
3096-
"**": ast.Pow(),
3097-
}
3098-
30993084
unaryop = {
31003085
"~": ast.Invert(),
31013086
"+": ast.UAdd(),
@@ -3298,6 +3283,74 @@ def test_folding_type_param_in_type_alias(self):
32983283
)
32993284
self.assert_ast(result_code, non_optimized_target, optimized_target)
33003285

3286+
def test_folding_match_case_allowed_expressions(self):
3287+
source = textwrap.dedent("""
3288+
match 0:
3289+
case -0: pass
3290+
case -0.1: pass
3291+
case -0j: pass
3292+
case 1 + 2j: pass
3293+
case 1 - 2j: pass
3294+
case 1.1 + 2.1j: pass
3295+
case 1.1 - 2.1j: pass
3296+
case -0 + 1j: pass
3297+
case -0 - 1j: pass
3298+
case -0.1 + 1.1j: pass
3299+
case -0.1 - 1.1j: pass
3300+
case {-0: 0}: pass
3301+
case {-0.1: 0}: pass
3302+
case {-0j: 0}: pass
3303+
case {1 + 2j: 0}: pass
3304+
case {1 - 2j: 0}: pass
3305+
case {1.1 + 2.1j: 0}: pass
3306+
case {1.1 - 2.1j: 0}: pass
3307+
case {-0 + 1j: 0}: pass
3308+
case {-0 - 1j: 0}: pass
3309+
case {-0.1 + 1.1j: 0}: pass
3310+
case {-0.1 - 1.1j: 0}: pass
3311+
case {-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}: pass
3312+
""")
3313+
expected_constants = (
3314+
0,
3315+
-0.1,
3316+
complex(0, -0),
3317+
complex(1, 2),
3318+
complex(1, -2),
3319+
complex(1.1, 2.1),
3320+
complex(1.1, -2.1),
3321+
complex(-0, 1),
3322+
complex(-0, -1),
3323+
complex(-0.1, 1.1),
3324+
complex(-0.1, -1.1),
3325+
(0, ),
3326+
(-0.1, ),
3327+
(complex(0, -0), ),
3328+
(complex(1, 2), ),
3329+
(complex(1, -2), ),
3330+
(complex(1.1, 2.1), ),
3331+
(complex(1.1, -2.1), ),
3332+
(complex(-0, 1), ),
3333+
(complex(-0, -1), ),
3334+
(complex(-0.1, 1.1), ),
3335+
(complex(-0.1, -1.1), ),
3336+
(0, complex(0, 1), complex(0.1, 1))
3337+
)
3338+
consts = iter(expected_constants)
3339+
tree = ast.parse(source, optimize=1)
3340+
match_stmt = tree.body[0]
3341+
for case in match_stmt.cases:
3342+
pattern = case.pattern
3343+
if isinstance(pattern, ast.MatchValue):
3344+
self.assertIsInstance(pattern.value, ast.Constant)
3345+
self.assertEqual(pattern.value.value, next(consts))
3346+
elif isinstance(pattern, ast.MatchMapping):
3347+
keys = iter(next(consts))
3348+
for key in pattern.keys:
3349+
self.assertIsInstance(key, ast.Constant)
3350+
self.assertEqual(key.value, next(keys))
3351+
else:
3352+
self.fail(f"Expected ast.MatchValue or ast.MatchMapping, found: {type(pattern)}")
3353+
33013354

3302< 4182 code>3355
if __name__ == '__main__':
33033356
if len(sys.argv) > 1 and sys.argv[1] == '--snapshot-update':

0 commit comments

Comments
 (0)
0