gh-63416: Speed up assertEqual on long sequences by jdevries3133 · Pull Request #27434 · python/cpython · GitHub

gh-63416: Speed up assertEqual on long sequences #27434


Closed
wants to merge 37 commits into from
Changes from 30 commits
Commits
37 commits
20fde31
Fix bpo-19217
eamanu Dec 17, 2018
6996d8a
fix unnecessary space and solve @gpshead comments
eamanu Dec 18, 2018
adfcf6d
fix test
Dec 18, 2018
ddebcfa
Merge remote-tracking branch 'temp/fix_bpo-19217' into bpo-19217-slow…
jdevries3133 Jul 22, 2021
81900d4
bpo-19217: fix failing test in unittest's test suite
jdevries3133 Jul 22, 2021
e5cb7bc
bpo-19217: add blurb
jdevries3133 Jul 22, 2021
6955d81
revert changes to difflib.py
jdevries3133 Jul 22, 2021
46f5e29
add regression test
jdevries3133 Jul 22, 2021
847bfc4
Merge branch 'main' of github.com:python/cpython into bpo-19217-slow-…
jdevries3133 Aug 3, 2021
856029d
draft implementation of unittest.case._heuristic_diff
jdevries3133 Aug 4, 2021
1dd1bcd
fix: remove now-unused imports from test_case.py
jdevries3133 Aug 4, 2021
1f45b28
merge upstream python/cpython:main
jdevries3133 Aug 12, 2021
988740a
add variably scaled test cases, misc updates & revisions
jdevries3133 Aug 12, 2021
4661f75
move Test_HeuristicDiff beneath main tests
jdevries3133 Aug 12, 2021
a9d23c4
remove unnecessary list comprehension in Lib/unittest/case.py
jdevries3133 Aug 13, 2021
192d7a4
spelling error in Lib/unittest/test/test_case.py
jdevries3133 Aug 13, 2021
74895e9
implement second review from @ambv
jdevries3133 Aug 13, 2021
e21dbe3
merge changes from PR suggestions
jdevries3133 Aug 13, 2021
5e8a186
Merge branch 'main' of github.com:python/cpython into bpo-19217-slow-…
jdevries3133 Jan 29, 2022
28cb042
fix from @JelleZijlstra
jdevries3133 Jan 29, 2022
63ecc45
fix from @JelleZijlstra
jdevries3133 Jan 29, 2022
e4344cb
remove unnecessary type checker supression
jdevries3133 Jan 29, 2022
0bdf06c
fix typo
jdevries3133 Jan 29, 2022
05ffdf2
simplify tuple syntax
jdevries3133 Jan 29, 2022
5767d21
fix news entry "~lists~ => sequenecs"
jdevries3133 Jan 29, 2022
b9f2f9d
better document the reasoning behind the heuristic
jdevries3133 Jan 29, 2022
5351506
thanks @JelleZijlstra
jdevries3133 Jan 29, 2022
872de08
thanks @JelleZijlstra
jdevries3133 Jan 29, 2022
180c732
fix whitespace around keyword argument
jdevries3133 Jan 29, 2022
200a1a7
Merge branch 'bpo-19217-slow-assertEq' of github.com:jdevries3133/cpy…
jdevries3133 Jan 29, 2022
fc752cf
update blurb to describe observable changes
jdevries3133 Feb 7, 2022
b53738f
Merge branch 'main' of github.com:python/cpython into bpo-19217-slow-…
jdevries3133 Feb 7, 2022
adc7d6f
reword to make less promises
gpshead Feb 8, 2022
c4d6f38
Merge branch 'main' of github.com:python/cpython into bpo-19217-slow-…
jdevries3133 Feb 8, 2022
9e6a464
Merge branch 'bpo-19217-slow-assertEq' of github.com:jdevries3133/cpy…
jdevries3133 Feb 8, 2022
0ad029f
Merge remote-tracking branch 'upstream/main' into bpo-19217-slow-asse…
AA-Turner Apr 9, 2025
4c5175c
Remove annotations
AA-Turner Apr 9, 2025
51 changes: 46 additions & 5 deletions Lib/unittest/case.py
@@ -5,6 +5,7 @@
import difflib
import pprint
import re
from collections.abc import Iterator
import warnings
import collections
import contextlib
@@ -170,6 +171,46 @@ def _is_subtype(expected, basetype):
return all(_is_subtype(e, basetype) for e in expected)
return isinstance(expected, type) and issubclass(expected, basetype)


def _heuristic_diff(a: list[str], b: list[str]) -> Iterator[str]:
    """After testing the magnitude of the inputs, preferably return the output
    of difflib.ndiff, but fallback to difflib.unified_diff for prohibitively
    expensive inputs. How cost is calculated:

    Cost is calculated according to this heuristic:

        cost = (number of differing lines
                * total length of all differing lines)

    This heuristic is used because the time complexity of ndiff is
    approximately O((diff)^2), where `diff` is the product of the number of
    differing lines, and the total length of differing lines. On the other
    hand, unified_diff's cost is the same as the cost of producing `diff`
    by itself: O(a + b).
@tim-one (Member) commented on Feb 8, 2022:

Note that this isn't true. Both have worst-case quadratic time, but on different kinds of input, and it's less likely to get provoked in the simpler kind of differencing unified_diff tries.

Well, actually actually 😉, the possible worst case of ndiff is worse than that, at least cubic time.
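
To make the distinction in this comment concrete, here is a small sketch (not part of the PR; the sample strings are made up for illustration). It shows that `difflib.ndiff` also compares similar lines character by character and emits `? ` hint lines, while `difflib.unified_diff` only works at line granularity. It says nothing about running time; it only shows the extra work ndiff attempts.

```python
import difflib

# Two short sequences whose first lines differ by a single character.
a = ["hello world\n", "same line\n"]
b = ["hello wurld\n", "same line\n"]

ndiff_lines = list(difflib.ndiff(a, b))
udiff_lines = list(
    difflib.unified_diff(a, b, fromfile="expected", tofile="got")
)

# ndiff marks intraline differences with lines that start with "? ".
ndiff_has_hints = any(line.startswith("? ") for line in ndiff_lines)

# unified_diff never emits "? " lines; it emits "@@ ... @@" hunk headers.
udiff_has_hints = any(line.startswith("? ") for line in udiff_lines)
udiff_has_hunk = any(line.startswith("@@") for line in udiff_lines)

print("".join(ndiff_lines))
print("".join(udiff_lines))
```

The intraline pass is where ndiff spends additional effort on pairs of similar lines, which is one reason its cost depends on line content and not only on line count.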


    See bpo-19217 for additional context.

Suggested change
- See bpo-19217 for additional context.
+ See gh-63416 for additional context.

"""
    COST_LIMIT = 1_000_000

    # call unified_diff
    udiff = list(difflib.unified_diff(a, b, fromfile="expected", tofile="got"))
    udiff_differing_lines = [l for l in udiff
                             if l.startswith(('-', '+'))]
Comment on lines +216 to +217

Suggested change
- udiff_differing_lines = [l for l in udiff
-                          if l.startswith(('-', '+'))]
+ udiff_differing_lines = [l for l in udiff if l.startswith(('-', '+'))]


    # inspect unified_diff output
    num_difflines = len(udiff_differing_lines)
    total_diffline_length = sum(len(l) for l in udiff_differing_lines)

Suggested change
- total_diffline_length = sum(len(l) for l in udiff_differing_lines)
+ total_diffline_length = sum(map(len, udiff_differing_lines))


    # now, we know what it will cost to call `ndiff`, according to the
    # heuristic
    diff_cost = num_difflines * total_diffline_length

    if diff_cost > COST_LIMIT:
        yield from udiff
    else:
        yield from difflib.ndiff(a, b)


class _BaseTestCaseContext:

def __init__(self, test_case):
@@ -1020,9 +1061,9 @@ def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
differing += ('Unable to index element %d '
'of second %s\n' % (len1, seq_type_name))
standardMsg = differing
- diffMsg = '\n' + '\n'.join(
- difflib.ndiff(pprint.pformat(seq1).splitlines(),
- pprint.pformat(seq2).splitlines()))
+ diffMsg = _heuristic_diff(pprint.pformat(seq1).splitlines(),
+ pprint.pformat(seq2).splitlines())
+ diffMsg = '\n' + '\n'.join(diffMsg)

standardMsg = self._truncateMessage(standardMsg, diffMsg)
msg = self._formatMessage(msg, standardMsg)
@@ -1133,7 +1174,7 @@ def assertDictEqual(self, d1, d2, msg=None):

if d1 != d2:
standardMsg = '%s != %s' % _common_shorten_repr(d1, d2)
- diff = ('\n' + '\n'.join(difflib.ndiff(
+ diff = ('\n' + '\n'.join(_heuristic_diff(
pprint.pformat(d1).splitlines(),
pprint.pformat(d2).splitlines())))
standardMsg = self._truncateMessage(standardMsg, diff)
@@ -1216,7 +1257,7 @@ def assertMultiLineEqual(self, first, second, msg=None):
firstlines = [first + '\n']
secondlines = [second + '\n']
standardMsg = '%s != %s' % _common_shorten_repr(first, second)
- diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
+ diff = '\n' + ''.join(_heuristic_diff(firstlines, secondlines))
standardMsg = self._truncateMessage(standardMsg, diff)
self.fail(self._formatMessage(msg, standardMsg))
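
The cost arithmetic in `_heuristic_diff` above can be reproduced in isolation. The following is a minimal sketch, not code from the PR: `estimate_cost` is a hypothetical helper name, and it simply repeats the same `unified_diff` call and `startswith(('-', '+'))` filter so the numbers can be inspected.

```python
import difflib

COST_LIMIT = 1_000_000  # same threshold as in the PR's _heuristic_diff


def estimate_cost(a, b):
    """Return (num_difflines, total_length, cost) as the heuristic computes them.

    Hypothetical helper for illustration; it mirrors the filtering and
    multiplication done inside _heuristic_diff.
    """
    udiff = list(difflib.unified_diff(a, b, fromfile="expected", tofile="got"))
    differing = [line for line in udiff if line.startswith(("-", "+"))]
    num_difflines = len(differing)
    total_length = sum(map(len, differing))
    return num_difflines, total_length, num_difflines * total_length


# Small input: cost stays far below the limit, so ndiff would be used.
small = estimate_cost(["foo"], ["bar"])

# One long line per side, as in test_unified_diff_is_used_with_large_inputs.
n = 50_000
large = estimate_cost(["foo" * n], ["bar" * n])

print(small, small[2] > COST_LIMIT)
print(large, large[2] > COST_LIMIT)
```

One thing this makes visible: the filter also matches the `--- expected` and `+++ got` header lines, so they are counted as differing lines. For the long-line input that gives 4 lines with a total length of 300,023 and a cost of 1,200,092, which is above `COST_LIMIT`. Counting only the two content lines would give 2 * 300,002 = 600,004, which is below it.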

236 changes: 234 additions & 2 deletions Lib/unittest/test/test_case.py
@@ -1,4 +1,5 @@
import contextlib
from dataclasses import dataclass
import difflib
import pprint
import pickle
@@ -9,6 +10,7 @@
import weakref
import inspect
import types
from typing import Iterator

from copy import deepcopy
from test import support
@@ -821,8 +823,9 @@ def testAssertSequenceEqualMaxDiff(self):
self.assertEqual(self.maxDiff, 80*8)
seq1 = 'a' + 'x' * 80**2
seq2 = 'b' + 'x' * 80**2
- diff = '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(),
- pprint.pformat(seq2).splitlines()))
+ diff = unittest.case._heuristic_diff(pprint.pformat(seq1).splitlines(),
+ pprint.pformat(seq2).splitlines())
+ diff = '\n'.join(diff)
# the +1 is the leading \n added by assertSequenceEqual
omitted = unittest.case.DIFF_OMITTED % (len(diff) + 1,)

@@ -1973,5 +1976,234 @@ def test2(self):
self.assertEqual(MyException.ninstance, 0)


class Test_HeuristicDiff(unittest.TestCase):

# this large constant coerces the use of `unified_diff` for several tests
N = 50_000

@staticmethod
def is_unified_diff(diff: Iterator[str]) -> bool:
"""Check for the presence of the @@ ... @@ diff summary line."""
diffstr = ''.join(diff)
p = r'@@ -(\d(,)?(\d)?)+ \+(\d(,)?(\d)?)+ @@'
mo = re.search(p, diffstr)
return bool(mo)

def test_is_unified_diff(self):
"""Test the helper above"""
ud = difflib.unified_diff('foo', 'bar')
nd = difflib.ndiff('foo', 'bar')
self.assertTrue(self.is_unified_diff(ud))
self.assertFalse(self.is_unified_diff(nd))

def assertHeuristicDiffReturns(self, a, b, expect: tuple[str, ...]):
"""check that _heuristic_diff(a, b) == expect"""
diff_iterable = unittest.case._heuristic_diff(a, b)
diff = tuple(diff_iterable)
self.assertTrue(diff == expect)

def test_ndiff_is_used_with_small_inputs(self):
a = ('foo',)
b = ('bar',)
expect = ('- foo', '+ bar')
self.assertHeuristicDiffReturns(a, b, expect)

def test_unified_diff_is_used_with_large_inputs(self):
"""One long line, as well as many single-character lines."""
# one long line
a = ('foo' * self.N,)
b = ('bar' * self.N,)
expect = ('--- expected\n', '+++ got\n', '@@ -1 +1 @@\n',
'-' + 'foo' * self.N,
'+' + 'bar' * self.N)
self.assertHeuristicDiffReturns(a, b, expect)

# many lines
a = ('1\n' * self.N).splitlines()
b = ('2\n' * self.N).splitlines()
expect = ('--- expected\n', '+++ got\n', f'@@ -1,{self.N} +1,{self.N} @@\n',
*(['-1'] * self.N),
*(['+2'] * self.N))
self.assertHeuristicDiffReturns(a, b, expect)

def test_ndiff_to_unified_diff_breaking_point_long_line(self):
"""This is the approximate single line length at which the heuristic
will switch from ndiff to unified_diff."""
expect_switch_at = 125_000
a = ''
b = ''
n = expect_switch_at // 2
while '@' not in ''.join(unittest.case._heuristic_diff(a, b)):
n *= 1.3
a = ('a' * int(n),)
b = ('b' * int(n),)

self.assertGreater(n, expect_switch_at * 0.8)
self.assertLess(n, expect_switch_at * 1.1)

def test_ndiff_to_unified_diff_breaking_point_many_lines(self):
"""For lines just one character long, the heuristic will switch from
ndiff to unified_diff around 70,000 differing lines."""
expect_switch_at = 70_000
a = ''
b = ''
n = expect_switch_at // 2
while '@' not in ''.join(unittest.case._heuristic_diff(a, b)):
n *= 1.3
a = ('a\n' * int(n),)
b = ('b\n' * int(n),)

self.assertGreater(n, expect_switch_at * 0.8)
self.assertLess(n, expect_switch_at * 1.1)

def test_ndiff_to_unified_diff_scaled_line_and_cols(self):
"""Scale line length and number of differing columns at different
rates, expecting a switch to `unified_diff` at specified points.
"""

@dataclass
class Case:
"""Case class specifies parameters for all tests."""
line_length_factor: int
num_lines_factor: int
extent_differing: float

# this is a "magic number" for all cases where the heuristic
# will switch from using ndiff to unified_diff.
expect_unified_diff_at: int

# --- Layout test cases
# ---------------------

cases = (
# scale width and length by ratios of 2:1
Case(
line_length_factor=1,
num_lines_factor=2,
extent_differing=1,
expect_unified_diff_at=22,
),
Case(
line_length_factor=2,
num_lines_factor=1,
extent_differing=1,
expect_unified_diff_at=28,
),

# scale width and length by ratios of 3:1
Case(
line_length_factor=1,
num_lines_factor=3,
extent_differing=1,
expect_unified_diff_at=16,
),
Case(
line_length_factor=3,
num_lines_factor=1,
extent_differing=1,
expect_unified_diff_at=24,
),

# scale by ratios of 3:1, with only 40% differing
Case(
line_length_factor=3,
num_lines_factor=1,
extent_differing=1,
expect_unified_diff_at=24,
),
Case(
line_length_factor=1,
num_lines_factor=3,
extent_differing=0.4,
expect_unified_diff_at=16,
),
Case(
line_length_factor=3,
num_lines_factor=1,
extent_differing=0.4,
expect_unified_diff_at=23,
),
)

# --- Execute test cases
# ----------------------

def run_case(case: Case, N):
"""Given one of the test cases above, execute the test case for a
given `N` constant value. Check if the test has passed as
specified."""

# we are working our way towards _heuristic_diff(foo, bar)

# --- Construct Differing Strings ---

# construct foo. Double line count because bar will have twice
# as many lines (lines of 'a' and lines of 'b')
foo = (
('a' * N * (case.line_length_factor),) # create line
* (N * case.num_lines_factor * 2) # duplicate line
)

# construct bar
bar_a_line = ('a' * (N * case.line_length_factor))
bar_b_line = ('b' * (N * case.line_length_factor))
bar_a_lines = ((bar_a_line,) * (N * case.num_lines_factor))
if case.extent_differing != 1:
# diminish the amount of 'b' by case.extent_differing, and add
# additional 'a' at the end as padding
bar_b_lines = ((bar_b_line,)
* int(N
* case.num_lines_factor
* case.extent_differing))
bar_a_padding = ((bar_a_line,)
* int(N
* case.num_lines_factor
* (1 - case.extent_differing)))
bar = (*bar_a_lines, *bar_b_lines, *bar_a_padding)
else:
bar_b_lines = ((bar_b_line,) * int(N * case.num_lines_factor))
bar = (*bar_a_lines, *bar_b_lines)

# --- Perform Diff ---

# after all that, we have `foo` and `bar`; two string sequences
# with differences as specified by the Case parameters, scaled
# by a factor of `N`.
diff = unittest.case._heuristic_diff(foo, bar)

# --- Make Assertions ---

# now, check that the `case.expect_unified_diff_at` condition was
# met
if (
N < case.expect_unified_diff_at
and self.is_unified_diff(diff)
):
self.fail('Switched to `unified_diff` prematurely. Expected '
f'switch at {case.expect_unified_diff_at}, but '
f'actually switched at {N} for the case {case}')
elif (
N > case.expect_unified_diff_at
and not self.is_unified_diff(diff)
):
self.fail('Switch to `unified_diff` did not occur. Expected '
f'switch at {case.expect_unified_diff_at}, but no '
f'switch occurred when N == {N}')

for case in cases:
for N in range(10, case.expect_unified_diff_at + 1):
run_case(case, N)

def test_ndiff_is_always_used_for_similar_sequences(self):
"""ndiff is perfectly efficient at showing small diffs. As long as
the difference between `a` and `b` are small, the size of `a` and `b`
should not disqualify the use of ndiff."""
a = ('foo ' * 5 + '\n') * 10_000
b = ('foo ' * 5 + '\n') * 9_999 + ('bar ' * 5 + '\n')

diff = unittest.case._heuristic_diff(a, b)
self.assertFalse(self.is_unified_diff(diff))


if __name__ == "__main__":
unittest.main()
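
The `is_unified_diff` helper in the test class above joins the whole diff into one string and searches it with a fairly permissive pattern. As a possible alternative, here is a sketch that checks each line against an anchored hunk-header pattern. The names `HUNK_HEADER` and `looks_like_unified_diff` are hypothetical and do not appear in the PR.

```python
import difflib
import re

# Hypothetical pattern: matches hunk headers such as "@@ -1 +1 @@" and
# "@@ -1,3 +1,3 @@" at the start of a line.
HUNK_HEADER = re.compile(r"@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@")


def looks_like_unified_diff(diff_lines):
    """Return True if any line of the diff begins with a unified hunk header."""
    return any(HUNK_HEADER.match(line) for line in diff_lines)


# Same inputs as test_is_unified_diff: strings are treated as sequences
# of one-character "lines".
print(looks_like_unified_diff(difflib.unified_diff("foo", "bar")))
print(looks_like_unified_diff(difflib.ndiff("foo", "bar")))
```

Anchoring at the start of each line avoids a false positive when the compared content itself contains text that resembles a hunk header in the middle of a line, at the price of assuming the diff is passed as separate lines.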
@@ -0,0 +1,2 @@
Optimize :meth:`~unittest.TestCase.assertEqual` method for long sequences of varied
items.