8000 bpo-15987: Implement ast.compare · python/cpython@cfb508f · GitHub
[go: up one dir, main page]

Skip to content

Commit cfb508f

Browse files
committed
bpo-15987: Implement ast.compare
1 parent d8ff44c commit cfb508f

File tree

5 files changed

+190
-5
lines changed

5 files changed

+190
-5
lines changed

Doc/library/ast.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,17 @@ and classes for traversing abstract syntax trees:
17481748
Added the *indent* option.
17491749

17501750

1751+
.. function:: compare(a, b, /, *, compare_types=True, compare_fields=True, compare_attributes=False)
1752+
1753+
Recursively compare given 2 AST nodes. If *compare_types* is ``False``, the
1754+
field values won't be checked whether they belong to same type or not. If
1755+
*compare_fields* is ``True``, members of ``_fields`` attribute on both node's
1756+
type will be checked. If *compare_attributes* is ``True``, members of
1757+
``_attributes`` attribute on both node's will be compared.
1758+
1759+
.. versionadded:: 3.9
1760+
1761+
17511762
.. _ast-cli:
17521763

17531764
Command-Line Usage

Doc/whatsnew/3.9.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ that would produce an equivalent :class:`ast.AST` object when parsed.
167167
Added docstrings to AST nodes that contains the ASDL signature used to
168168
construct that node. (Contributed by Batuhan Taskaya in :issue:`39638`.)
169169

170+
Added :func:`ast.compare` for comparing 2 AST node.
171+
(Contributed by Batuhan Taskaya in :issue:`15987`)
172+
170173
asyncio
171174
-------
172175

Lib/ast.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,69 @@ def walk(node):
368368
yield node
369369

370370

371+
def compare(
372+
a,
373+
b,
374+
/,
375+
*,
376+
compare_types=True,
377+
compare_fields=True,
378+
compare_attributes=False,
379+
):
380+
"""
381+
Compares recursively given two ast nodes. If *compare_types* is False, the
382+
field values won't be checked whether they belong to same type or not. If
383+
*compare_fields* is True, members of `_fields` attribute on both node's type
384+
will be checked. If *compare_attributes* is True, members of `_attributes`
385+
attribute on both node's will be compared.
386+
"""
387+
388+
def _compare(a, b):
389+
if compare_types and type(a) is not type(b):
390+
return False
391+
elif isinstance(a, AST):
392+
return compare(
393+
a,
394+
b,
395+
compare_attributes=compare_attributes,
396+
compare_fields=compare_fields,
397+
compare_types=compare_types,
398+
)
399+
elif isinstance(a, list):
400+
if len(a) != len(b):
401+
return False
402+
for a_item, b_item in zip(a, b):
403+
if not _compare(a_item, b_item):
404+
return False
405+
else:
406+
return True
407+
else:
408+
return a == b
409+
410+
def _compare_member(member):
411+
for field in getattr(a, member):
412+
if not hasattr(a, field) and not hasattr(b, field):
413+
continue
414+
if not (hasattr(a, field) and hasattr(b, field)):
415+
return False
416+
a_field = getattr(a, field)
417+
b_field = getattr(b, field)
418+
if not _compare(a_field, b_field):
419+
return False
420+
else:
421+
return True
422+
423+
if type(a) is not type(b):
424+
return False
425+
if compare_fields:
426+
if not _compare_member("_fields"):
427+
return False
428+
if compare_attributes:
429+
if not _compare_member("_attributes"):
430+
return False
431+
return True
432+
433+
371434
class NodeVisitor(object):
372435
"""
373436
A node visitor base class that walks the abstract syntax tree and calls a

Lib/test/test_ast.py

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import ast
22
import dis
33
import os
4+
import random
45
import sys
6+
import tokenize
57
import unittest
68
import warnings
79
import weakref
@@ -25,6 +27,9 @@ def to_tuple(t):
2527
result.append(to_tuple(getattr(t, f)))
2628
return tuple(result)
2729

30+
STDLIB = os.path.dirname(ast.__file__)
31+
STDLIB_FILES = [fn for fn in os.listdir(STDLIB) if fn.endswith(".py")]
32+
STDLIB_FILES.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
2833

2934
# These tests are compiled through "exec"
3035
# There should be at least one test per statement
@@ -654,6 +659,110 @@ def test_ast_asdl_signature(self):
654659
expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
655660
self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)
656661

662+
self.assertTrue(ast.compare(ast.parse("x = 10"), ast.parse("x = 10")))
663+
self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("")))
664+
self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("x")))
665+
self.assertFalse(
666+
ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))
667+
)
668+
669+
def test_compare_literals(self):
670+
constants = (
671+
-20,
672+
20,
673+
20.0,
674+
1,
675+
1.0,
676+
True,
677+
0,
678+
False,
679+
frozenset(),
680+
tuple(),
681+
"ABCD",
682+
"abcd",
683+
"中文字",
684+
1e1000,
685+
-1e1000,
686+
)
687+
for next_index, constant in enumerate(constants[:-1], 1):
688+
next_constant = constants[next_index]
689+
with self.subTest(literal=constant, next_literal=next_constant):
690+
self.assertTrue(
691+
ast.compare(ast.Constant(constant), ast.Constant(constant))
692+
)
693+
self.assertFalse(
694+
ast.compare(
695+
ast.Constant(constant), ast.Constant(next_constant)
696+
)
697+
)
698+
699+
same_looking_literal_cases = [
700+
{1, 1.0, True, 1 + 0j},
701+
{0, 0.0, False, 0 + 0j},
702+
]
703+
for same_looking_literals in same_looking_literal_cases:
704+
for literal in same_looking_literals:
705+
for same_looking_literal in same_looking_literals - {literal}:
706+
self.assertFalse(
707+
ast.compare(
708+
ast.Constant(literal),
709+
ast.Constant(same_looking_literal),
710+
)
711+
)
712+
713+
def test_compare_fieldless(self):
714+
self.assertTrue(ast.compare(ast.Add(), ast.Add()))
715+
self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
716+
self.assertFalse(ast.compare(ast.Sub(), ast.Constant()))
717+
718+
def test_compare_stdlib(self):
719+
if support.is_resource_enabled("cpu"):
720+
files = STDLIB_FILES
721+
else:
722+
files = random.sample(STDLIB_FILES, 10)
723+
724+
for module in files:
725+
with self.subTest(module):
726+
fn = os.path.join(STDLIB, module)
727+
with tokenize.open(fn) as fp:
728+
source = fp.read()
729+
a = ast.parse(source, fn)
730+
b = ast.parse(source, fn)
731+
self.assertTrue(
732+
ast.compare(a, b), f"{ast.dump(a)} != {ast.dump(b)}"
733+
)
734+
735+
def test_compare_tests(self):
736+
for mode, sources in (
737+
("exec", exec_tests),
738+
("eval", eval_tests),
739+
("single", single_tests),
740+
):
741+
for source in sources:
742+
a = ast.parse(source, mode=mode)
743+
b = ast.parse(source, mode=mode)
744+
self.assertTrue(
745+
ast.compare(a, b), f"{ast.dump(a)} != {ast.dump(b)}"
746+
)
747+
748+
def test_compare_options(self):
749+
def parse(a, b):
750+
return ast.parse(a), ast.parse(b)
751+
752+
a, b = parse("2 + 2", "2+2")
753+
self.assertTrue(ast.compare(a, b, compare_attributes=False))
754+
self.assertFalse(ast.compare(a, b, compare_attributes=True))
755+
756+
a, b = parse("1", "1.0")
757+
self.assertTrue(ast.compare(a, b, compare_types=False))
758+
self.assertFalse(ast.compare(a, b, compare_types=True))
759+
760+
a, b = parse("1", "2")
761+
self.assertTrue(ast.compare(a, b, compare_fields=False, compare_attributes=False))
762+
self.assertTrue(ast.compare(a, b, compare_fields=False, compare_attributes=True))
763+
self.assertFalse(ast.compare(a, b, compare_fields=True, compare_attributes=False))
764+
self.assertFalse(ast.compare(a, b, compare_fields=True, compare_attributes=True))
765+
657766

658767
class ASTHelpers_Test(unittest.TestCase):
659768
maxDiff = None
@@ -1369,12 +1478,9 @@ def test_nameconstant(self):
13691478
self.expr(ast.NameConstant(4))
13701479

13711480
def test_stdlib_validates(self):
1372-
stdlib = os.path.dirname(ast.__file__)
1373-
tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
1374-
tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
1375-
for module in tests:
1481+
for module in STDLIB_FILES:
13761482
with self.subTest(module):
1377-
fn = os.path.join(stdlib, module)
1483+
fn = os.path.join(STDLIB, module)
13781484
with open(fn, "r", encoding="utf-8") as fp:
13791485
source = fp.read()
13801486
mod = ast.parse(source, fn)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Implemented :func:`ast.compare` for comparing 2 AST node. Patch by Batuhan
2+
Taskaya.

0 commit comments

Comments
 (0)
0