diff --git a/Lib/lib2to3/patcomp.py b/Lib/lib2to3/patcomp.py index f57f4954b26ce7..12c3222a5c46ee 100644 --- a/Lib/lib2to3/patcomp.py +++ b/Lib/lib2to3/patcomp.py @@ -145,11 +145,16 @@ def compile_basic(self, nodes, repeat=None): elif node.type == token.NAME: value = node.value if value.isupper(): - if value not in TOKEN_MAP: - raise PatternSyntaxError("Invalid token: %r" % value) + # Map named tokens to the type value for a LeafPattern + if value == 'TOKEN': + type = None + else: + type = getattr(token, value) + if not type: + raise PatternSyntaxError("Invalid token: %r" % value) if nodes[1:]: raise PatternSyntaxError("Can't have details for token") - return pytree.LeafPattern(TOKEN_MAP[value]) + return pytree.LeafPattern(type) else: if value == "any": type = None @@ -175,14 +180,8 @@ def get_int(self, node): return int(node.value) -# Map named tokens to the type value for a LeafPattern -TOKEN_MAP = {"NAME": token.NAME, - "STRING": token.STRING, - "NUMBER": token.NUMBER, - "TOKEN": None} - - def _type_of_literal(value): + # Special case: you can't match ASYNC or AWAIT in their new keyword forms this way. if value[0].isalpha(): return token.NAME elif value in grammar.opmap: diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/fix_await.py b/Lib/lib2to3/tests/data/fixers/myfixes/fix_await.py new file mode 100644 index 00000000000000..933ce46f39b1a0 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/fix_await.py @@ -0,0 +1,13 @@ +from lib2to3.fixer_base import BaseFix +from lib2to3.fixer_util import Name + +class FixAwait(BaseFix): + """ + Find calls to await and change their target. + """ + + PATTERN = """power < [AWAIT] name='b' any* >""" + + def transform(self, node, results): + name = results["name"] + name.replace(Name("bar", name.prefix)) diff --git a/Lib/lib2to3/tests/test_refactor.py b/Lib/lib2to3/tests/test_refactor.py index 9e3b8fbb90b2f3..07a1f037f4f85c 100644 --- a/Lib/lib2to3/tests/test_refactor.py +++ b/Lib/lib2to3/tests/test_refactor.py @@ -37,7 +37,7 @@ def tearDown(self): def check_instances(self, instances, classes): for inst, cls in zip(instances, classes): if not isinstance(inst, cls): - self.fail("%s are not instances of %s" % instances, classes) + self.fail("%s are not instances of %s" % (instances, classes)) def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): return refactor.RefactoringTool(fixers, options, explicit) @@ -55,7 +55,7 @@ def test_write_unchanged_files_option(self): self.assertTrue(rt.write_unchanged_files) def test_fixer_loading_helpers(self): - contents = ["explicit", "first", "last", "parrot", "preorder"] + contents = ["await", "explicit", "first", "last", "parrot", "preorder"] non_prefixed = refactor.get_all_fix_names("myfixes") prefixed = refactor.get_all_fix_names("myfixes", False) full_names = refactor.get_fixers_from_package("myfixes") @@ -132,6 +132,7 @@ class SimpleFix(fixer_base.BaseFix): def test_fixer_loading(self): from myfixes.fix_first import FixFirst + from myfixes.fix_await import FixAwait from myfixes.fix_last import FixLast from myfixes.fix_parrot import FixParrot from myfixes.fix_preorder import FixPreorder @@ -140,7 +141,7 @@ def test_fixer_loading(self): pre, post = rt.get_fixers() self.check_instances(pre, [FixPreorder]) - self.check_instances(post, [FixFirst, FixParrot, FixLast]) + self.check_instances(post, [FixFirst, FixAwait, FixParrot, FixLast]) def test_naughty_fixers(self): self.assertRaises(ImportError, self.rt, fixers=["not_here"]) @@ -331,3 +332,9 @@ def test_explicit(self): break else: self.fail("explicit fixer not loaded") + + def test_refactor_await(self): + rt = self.rt() + input = "async def a(): await b()\n\n" + tree = rt.refactor_string(input, "") + self.assertNotEqual(str(tree), input)