From dbeda9330141ac2a1a89d1fec6cf25ab4f16cd0c Mon Sep 17 00:00:00 2001 From: Tim Hatch Date: Mon, 11 Mar 2019 18:01:48 -0700 Subject: [PATCH] Support matching ASYNC and AWAIT as keywords. With this change, you can match 'await' to mean the NAME token possible with <3.8, and the AWAIT token for the proper position. Before this change, your only option is to match TOKEN and filter in your fixer. As an example of code improved by this, Bowler can now be more explicit about what it expects to find here, as seen at https://github.com/facebookincubator/Bowler/blob/da3ec9a88c41d050d591e0c9e15b44849f0c9631/bowler/query.py#L338 I did a cursory reading of https://bugs.python.org/issue30406 and https://bugs.python.org/issue35975 and think this should be compatible. --- Lib/lib2to3/patcomp.py | 19 +++++++++---------- .../tests/data/fixers/myfixes/fix_await.py | 13 +++++++++++++ Lib/lib2to3/tests/test_refactor.py | 13 ++++++++++--- 3 files changed, 32 insertions(+), 13 deletions(-) create mode 100644 Lib/lib2to3/tests/data/fixers/myfixes/fix_await.py diff --git a/Lib/lib2to3/patcomp.py b/Lib/lib2to3/patcomp.py index f57f4954b26ce7..12c3222a5c46ee 100644 --- a/Lib/lib2to3/patcomp.py +++ b/Lib/lib2to3/patcomp.py @@ -145,11 +145,16 @@ def compile_basic(self, nodes, repeat=None): elif node.type == token.NAME: value = node.value if value.isupper(): - if value not in TOKEN_MAP: - raise PatternSyntaxError("Invalid token: %r" % value) + # Map named tokens to the type value for a LeafPattern + if value == 'TOKEN': + type = None + else: + type = getattr(token, value) + if not type: + raise PatternSyntaxError("Invalid token: %r" % value) if nodes[1:]: raise PatternSyntaxError("Can't have details for token") - return pytree.LeafPattern(TOKEN_MAP[value]) + return pytree.LeafPattern(type) else: if value == "any": type = None @@ -175,14 +180,8 @@ def get_int(self, node): return int(node.value) -# Map named tokens to the type value for a LeafPattern -TOKEN_MAP = {"NAME": token.NAME, - "STRING": token.STRING, - "NUMBER": token.NUMBER, - "TOKEN": None} - - def _type_of_literal(value): + # Special case: you can't match ASYNC or AWAIT in their new keyword forms this way. if value[0].isalpha(): return token.NAME elif value in grammar.opmap: diff --git a/Lib/lib2to3/tests/data/fixers/myfixes/fix_await.py b/Lib/lib2to3/tests/data/fixers/myfixes/fix_await.py new file mode 100644 index 00000000000000..933ce46f39b1a0 --- /dev/null +++ b/Lib/lib2to3/tests/data/fixers/myfixes/fix_await.py @@ -0,0 +1,13 @@ +from lib2to3.fixer_base import BaseFix +from lib2to3.fixer_util import Name + +class FixAwait(BaseFix): + """ + Find calls to await and change their target. + """ + + PATTERN = """power < [AWAIT] name='b' any* >""" + + def transform(self, node, results): + name = results["name"] + name.replace(Name("bar", name.prefix)) diff --git a/Lib/lib2to3/tests/test_refactor.py b/Lib/lib2to3/tests/test_refactor.py index 9e3b8fbb90b2f3..07a1f037f4f85c 100644 --- a/Lib/lib2to3/tests/test_refactor.py +++ b/Lib/lib2to3/tests/test_refactor.py @@ -37,7 +37,7 @@ def tearDown(self): def check_instances(self, instances, classes): for inst, cls in zip(instances, classes): if not isinstance(inst, cls): - self.fail("%s are not instances of %s" % instances, classes) + self.fail("%s are not instances of %s" % (instances, classes)) def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): return refactor.RefactoringTool(fixers, options, explicit) @@ -55,7 +55,7 @@ def test_write_unchanged_files_option(self): self.assertTrue(rt.write_unchanged_files) def test_fixer_loading_helpers(self): - contents = ["explicit", "first", "last", "parrot", "preorder"] + contents = ["await", "explicit", "first", "last", "parrot", "preorder"] non_prefixed = refactor.get_all_fix_names("myfixes") prefixed = refactor.get_all_fix_names("myfixes", False) full_names = refactor.get_fixers_from_package("myfixes") @@ -132,6 +132,7 @@ class SimpleFix(fixer_base.BaseFix): def test_fixer_loading(self): from myfixes.fix_first import FixFirst + from myfixes.fix_await import FixAwait from myfixes.fix_last import FixLast from myfixes.fix_parrot import FixParrot from myfixes.fix_preorder import FixPreorder @@ -140,7 +141,7 @@ def test_fixer_loading(self): pre, post = rt.get_fixers() self.check_instances(pre, [FixPreorder]) - self.check_instances(post, [FixFirst, FixParrot, FixLast]) + self.check_instances(post, [FixFirst, FixAwait, FixParrot, FixLast]) def test_naughty_fixers(self): self.assertRaises(ImportError, self.rt, fixers=["not_here"]) @@ -331,3 +332,9 @@ def test_explicit(self): break else: self.fail("explicit fixer not loaded") + + def test_refactor_await(self): + rt = self.rt() + input = "async def a(): await b()\n\n" + tree = rt.refactor_string(input, "") + self.assertNotEqual(str(tree), input)