class Test_HeuristicDiff(unittest.TestCase):
    """Tests for ``unittest.case._heuristic_diff``.

    The helper under test yields ``difflib.ndiff`` output for inputs it
    considers cheap and ``difflib.unified_diff`` output otherwise.  These
    tests pin both paths and the approximate input sizes at which the helper
    switches from one to the other.
    """

    # this large constant coerces the use of `unified_diff` for several tests
    N = 50_000

    @staticmethod
    def is_unified_diff(diff):
        """Return True if *diff* contains a ``@@ ... @@`` hunk summary line.

        *diff* is any iterable of strings.  It is joined into one string, so
        a generator passed in is fully consumed by this call.
        """
        diffstr = ''.join(diff)
        p = r'@@ -(\d(,)?(\d)?)+ \+(\d(,)?(\d)?)+ @@'
        return re.search(p, diffstr) is not None

    def test_is_unified_diff(self):
        """Test the helper above"""
        ud = difflib.unified_diff('foo', 'bar')
        nd = difflib.ndiff('foo', 'bar')
        self.assertTrue(self.is_unified_diff(ud))
        self.assertFalse(self.is_unified_diff(nd))

    def assertHeuristicDiffReturns(self, a, b, expect):
        """Check that tuple(_heuristic_diff(a, b)) == expect."""
        diff_iterable = unittest.case._heuristic_diff(a, b)
        diff = tuple(diff_iterable)
        # NOTE(review): presumably assertTrue is used instead of assertEqual
        # so that a failure does not itself try to diff two huge tuples --
        # confirm.
        self.assertTrue(diff == expect)

    def test_ndiff_is_used_with_small_inputs(self):
        """A tiny input must come back in ndiff format."""
        a = ('foo',)
        b = ('bar',)
        expect = ('- foo', '+ bar')
        self.assertHeuristicDiffReturns(a, b, expect)

    def test_unified_diff_is_used_with_large_inputs(self):
        """One long line, as well as many single-character lines."""
        # one long line
        a = ('foo' * self.N,)
        b = ('bar' * self.N,)
        expect = ('--- expected\n', '+++ got\n', '@@ -1 +1 @@\n',
                  '-' + 'foo' * self.N,
                  '+' + 'bar' * self.N)
        self.assertHeuristicDiffReturns(a, b, expect)

        # many lines
        a = ('1\n' * self.N).splitlines()
        b = ('2\n' * self.N).splitlines()
        expect = ('--- expected\n', '+++ got\n',
                  f'@@ -1,{self.N} +1,{self.N} @@\n',
                  *(['-1'] * self.N),
                  *(['+2'] * self.N))
        self.assertHeuristicDiffReturns(a, b, expect)

    def test_ndiff_to_unified_diff_breaking_point_long_line(self):
        """This is the approximate single line length at which the heuristic
        will switch from ndiff to unified_diff."""
        expect_switch_at = 125_000
        a = ''
        b = ''
        n = expect_switch_at // 2
        # Grow the single line by 30% per step until an '@' (only present in
        # unified_diff output for these inputs) shows up.
        while '@' not in ''.join(unittest.case._heuristic_diff(a, b)):
            n *= 1.3
            a = ('a' * int(n),)
            b = ('b' * int(n),)

        self.assertGreater(n, expect_switch_at * 0.8)
        self.assertLess(n, expect_switch_at * 1.1)

    def test_ndiff_to_unified_diff_breaking_point_many_lines(self):
        """For lines just one character long, the heuristic will switch from
        ndiff to unified_diff around 70,000 differing lines."""
        expect_switch_at = 70_000
        a = ''
        b = ''
        n = expect_switch_at // 2
        # NOTE(review): `('a\n' * int(n),)` is a one-element tuple whose only
        # item is a single string with embedded newlines, so as written this
        # feeds ONE long "line" to the heuristic rather than many
        # one-character lines.  Splitting into real lines would move the
        # switch point, so `expect_switch_at` would need recalibrating --
        # confirm intent before changing.
        while '@' not in ''.join(unittest.case._heuristic_diff(a, b)):
            n *= 1.3
            a = ('a\n' * int(n),)
            b = ('b\n' * int(n),)

        self.assertGreater(n, expect_switch_at * 0.8)
        self.assertLess(n, expect_switch_at * 1.1)

    def test_ndiff_to_unified_diff_scaled_line_and_cols(self):
        """Scale line length and number of differing columns at different
        rates, expecting a switch to `unified_diff` at specified points.
        """

        @dataclass
        class Case:
            """Case class specifies parameters for all tests."""
            line_length_factor: int
            num_lines_factor: int
            extent_differing: float

            # this is a "magic number" for all cases where the heuristic
            # will switch from using ndiff to unified_diff.
            expect_unified_diff_at: int

        # --- Layout test cases
        # ---------------------

        cases = (
            # scale width and length by ratios of 2:1
            Case(
                line_length_factor=1,
                num_lines_factor=2,
                extent_differing=1,
                expect_unified_diff_at=22,
            ),
            Case(
                line_length_factor=2,
                num_lines_factor=1,
                extent_differing=1,
                expect_unified_diff_at=28,
            ),

            # scale width and length by ratios of 3:1
            Case(
                line_length_factor=1,
                num_lines_factor=3,
                extent_differing=1,
                expect_unified_diff_at=16,
            ),
            Case(
                line_length_factor=3,
                num_lines_factor=1,
                extent_differing=1,
                expect_unified_diff_at=24,
            ),

            # # scale by ratios of 3:1, with only 40% differing
            # NOTE(review): the next Case has extent_differing=1 and is an
            # exact duplicate of the 3:1 case above -- confirm whether 0.4
            # was intended or whether it can be dropped.
            Case(
                line_length_factor=3,
                num_lines_factor=1,
                extent_differing=1,
                expect_unified_diff_at=24,
            ),
            Case(
                line_length_factor=1,
                num_lines_factor=3,
                extent_differing=0.4,
                expect_unified_diff_at=16,
            ),
            Case(
                line_length_factor=3,
                num_lines_factor=1,
                extent_differing=0.4,
                expect_unified_diff_at=23,
            ),
        )

        # --- Execute test cases
        # ----------------------

        def run_case(case, N):
            """Given one of the test cases above, execute the test case for a
            given `N` constant value. Check if the test has passed as
            specified."""

            # we are working our way towards _heuristic_diff(foo, bar)

            # --- Construct Differing Strings ---

            # construct foo. Double line count because bar will have twice
            # as many lines (lines of 'a' and lines of 'b')
            foo = (
                ('a' * N * (case.line_length_factor),)  # create line
                * (N * case.num_lines_factor * 2)       # duplicate line
            )

            # construct bar
            bar_a_line = ('a' * (N * case.line_length_factor))
            bar_b_line = ('b' * (N * case.line_length_factor))
            bar_a_lines = ((bar_a_line,) * (N * case.num_lines_factor))
            if case.extent_differing != 1:
                # diminish the amount of 'b' by case.extent_differing, and add
                # additional 'a' at the end as padding
                bar_b_lines = ((bar_b_line,)
                               * int(N
                                     * case.num_lines_factor
                                     * case.extent_differing))
                bar_a_padding = ((bar_a_line,)
                                 * int(N
                                       * case.num_lines_factor
                                       * (1 - case.extent_differing)))
                bar = (*bar_a_lines, *bar_b_lines, *bar_a_padding)
            else:
                bar_b_lines = ((bar_b_line,) * int(N * case.num_lines_factor))
                bar = (*bar_a_lines, *bar_b_lines)

            # --- Perform Diff ---

            # after all that, we have `foo` and `bar`; two string sequences
            # with differences as specified by the Case parameters, scaled
            # by a factor of `N`.
            diff = unittest.case._heuristic_diff(foo, bar)

            # --- Make Assertions ---

            # now, check that the `case.expect_unified_diff_at` condition was
            # met.  `diff` is a single-use generator; the short-circuiting
            # `and` below guarantees it is consumed by at most one of the two
            # is_unified_diff() calls.
            if (
                N < case.expect_unified_diff_at
                and self.is_unified_diff(diff)
            ):
                self.fail('Switched to `unified_diff` prematurely. Expected '
                          f'switch at {case.expect_unified_diff_at}, but '
                          f'actually switched at {N} for the case {case}')
            elif (
                N > case.expect_unified_diff_at
                and not self.is_unified_diff(diff)
            ):
                # Every fragment that interpolates needs its own f prefix.
                self.fail('Switch to `unified_diff` did not occur. Expected '
                          f'switch at {case.expect_unified_diff_at}, but no '
                          f'switch occurred when N == {N}')

        # NOTE(review): N never exceeds case.expect_unified_diff_at here, so
        # the "did not occur" branch in run_case is never reached and each
        # threshold is only verified as a lower bound.  Extending the range
        # would also require re-validating the thresholds above.
        for case in cases:
            for N in range(10, case.expect_unified_diff_at + 1):
                run_case(case, N)

    def test_ndiff_is_always_used_for_similar_sequences(self):
        """ndiff is perfectly efficient at showing small diffs. As long as
        the difference between `a` and `b` are small, the size of `a` and `b`
        should not disqualify the use of ndiff."""
        # Pass sequences of lines (as the real callers do) rather than raw
        # strings, which would be diffed one character at a time.
        a = (('foo ' * 5 + '\n') * 10_000).splitlines(keepends=True)
        b = (('foo ' * 5 + '\n') * 9_999
             + ('bar ' * 5 + '\n')).splitlines(keepends=True)

        diff = unittest.case._heuristic_diff(a, b)
        self.assertFalse(self.is_unified_diff(diff))
def _heuristic_diff(a, b, *, cost_limit=1_000_000):
    """Yield a line diff of *a* and *b*, choosing the algorithm by cost.

    The output of difflib.ndiff is preferred, but ndiff can be prohibitively
    expensive for inputs that differ a lot (see bpo-19217).  To guard
    against that, difflib.unified_diff is always run first and its output is
    used to estimate what calling ndiff would cost:

        cost = (number of differing lines
                * total length of all differing lines)

    "Differing lines" are the unified_diff output lines starting with '-'
    or '+'.  This includes the two file header lines ('--- expected' and
    '+++ got') that unified_diff emits whenever the inputs differ.

    If the cost is greater than *cost_limit*, the unified_diff output that
    was already computed is yielded; otherwise the output of
    difflib.ndiff(a, b) is yielded.

    *a* and *b* are sequences of strings, as accepted by difflib.ndiff and
    difflib.unified_diff.  *cost_limit* is keyword-only; its default is the
    threshold used by the assert* methods.

    This is a generator function, so nothing is computed until the result
    is iterated.
    """
    # Run unified_diff unconditionally and keep the result: it serves both
    # as the probe for the cost estimate and as the fallback output.
    unified = list(difflib.unified_diff(a, b,
                                        fromfile="expected", tofile="got"))

    # Single pass over the probe output to collect both cost factors.
    num_difflines = 0
    total_diffline_length = 0
    for line in unified:
        if line.startswith(('-', '+')):
            num_difflines += 1
            total_diffline_length += len(line)

    if num_difflines * total_diffline_length > cost_limit:
        yield from unified
    else:
        yield from difflib.ndiff(a, b)
a/Misc/NEWS.d/next/Library/2021-07-21-22-57-51.bpo-19217.vm-cr-.rst b/Misc/NEWS.d/next/Library/2021-07-21-22-57-51.bpo-19217.vm-cr-.rst new file mode 100644 index 00000000000000..6814851d4d583b --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-07-21-22-57-51.bpo-19217.vm-cr-.rst @@ -0,0 +1,9 @@ +Optimize :meth:`~unittest.TestCase.assertEqual` for long sequences of varied +items. Based on an internal heuristic, the algorithm used to produce the diff +will switch based on the magnitude of the inputs. As a result, diff +output after a failing assertion may look different for large inputs. +Specifically, unittest will internally switch from using :func:`difflib.ndiff` +(slow) to using :func:`difflib.unified_diff` (less likely to be slow). + +This optimization reduces the chance that the non-linear time complexity of +diff algorithms causes a test suite's failing test to hang.