bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case by rhettinger · Pull Request #29828 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case #29828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bbd2da9
Merge pull request #1 from python/master
rhettinger Mar 16, 2021
74bdf1b
Merge branch 'master' of github.com:python/cpython
rhettinger Mar 22, 2021
6c53f1a
Merge branch 'master' of github.com:python/cpython
rhettinger Mar 22, 2021
a487c4f
.
rhettinger Mar 24, 2021
eb56423
.
rhettinger Mar 25, 2021
cc7ba06
.
rhettinger Mar 26, 2021
d024dd0
.
rhettinger Apr 22, 2021
b10f912
merge
rhettinger May 5, 2021
fb6744d
merge
rhettinger May 6, 2021
7f21a1c
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 15, 2021
7da42d4
Merge branch 'main' of github.com:rhettinger/cpython
rhettinger Aug 25, 2021
e31757b
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 31, 2021
f058a6f
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 31, 2021
1fc29bd
Merge branch 'main' of github.com:python/cpython
rhettinger Sep 4, 2021
e5c0184
Merge branch 'main' of github.com:python/cpython
rhettinger Oct 30, 2021
3c86ec1
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 9, 2021
96675e4
Merge branch 'main' of github.com:rhettinger/cpython
rhettinger Nov 9, 2021
de558c6
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 9, 2021
418a07f
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 14, 2021
ea23a8b
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 21, 2021
ba248b7
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 27, 2021
037b5fe
Correctly rounded stdev results for Decimal inputs
rhettinger Nov 29, 2021
f5c091c
Whitespace
rhettinger Nov 29, 2021
70cdade
Rename the functions consistently
rhettinger Nov 29, 2021
82dbec6
Improve comment
rhettinger Nov 29, 2021
1a6c58d
Tweak variable names
rhettinger Nov 29, 2021
b2385b0
Replace Fraction arithmetic with integer arithmetic
rhettinger Nov 29, 2021
594ea27
Add spacing between terms
rhettinger Nov 29, 2021
3911581
Fix type annotation
rhettinger Nov 29, 2021
a09e3c4
Return a Decimal zero when the numerator is zero
rhettinger Nov 29, 2021
152ed3f
Remove unused import
rhettinger Nov 29, 2021
80371c1
Factor lhs of inequality. Rename helper function for consistency.
rhettinger Nov 29, 2021
1c86e7c
Add comment for future work.
rhettinger Nov 29, 2021
0684fac
Fix typo in docstring. Refine wording in comment.
rhettinger Nov 29, 2021
8b5e377
Add more detail to the comment about numerator and denominator sizes
rhettinger Nov 30, 2021
d11d567
Improve variable name
rhettinger Nov 30, 2021
309cb0a
Avoid double rounding in test code
rhettinger Nov 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Correctly rounded stdev results for Decimal inputs
  • Loading branch information
rhettinger committed Nov 29, 2021
commit 037b5fea25bf59df4e513992a7cc34a79d18425b
33 changes: 30 additions & 3 deletions Lib/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,15 +305,18 @@ def _fail_neg(values, errmsg='negative value'):
raise StatisticsError(errmsg)
yield x


def _isqrt_frac_rto(n: int, m: int) -> int:
    """Square root of n/m, rounded to the nearest integer using round-to-odd.

    The result is math.isqrt(n // m) with its lowest bit forced on whenever
    that integer root is not an exact square root of n/m.  The return value
    is always an int.
    """
    # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
    a = math.isqrt(n // m)
    # The comparison yields a bool (0 or 1); OR-ing it into the root sets
    # the low bit exactly when a*a differs from n/m (an inexact root).
    return a | (a*a*m != n)


# Shift used by _sqrt_frac(): twice the float mantissa width plus three.
# For 53 bit precision floats, the _sqrt_frac() shift is 109.
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3


def _sqrt_frac(n: int, m: int) -> float:
"""Square root of n/m as a float, correctly rounded."""
# See principle and proof sketch at: https://bugs.python.org/msg407078
Expand All @@ -327,6 +330,31 @@ def _sqrt_frac(n: int, m: int) -> float:
return numerator / denominator # Convert to float


def _deci_sqrt(n: int, m: int) -> Decimal:
    """Square root of n/m as a Decimal, correctly rounded.

    The candidate root comes from Decimal.sqrt() in the current decimal
    context.  It is then compared, using exact Fraction arithmetic, against
    the midpoints to its neighbouring representable values, and the
    neighbour is returned instead when n/m lies beyond the midpoint.

    A zero numerator yields a Decimal zero.  A zero denominator with a
    nonzero numerator raises ZeroDivisionError (from Fraction(n, m)).
    A negative ratio is handed to Decimal.sqrt(), which signals
    decimal.InvalidOperation.
    """
    # Premise: For decimal, computing sqrt(n / m) can be off by 1 ulp.
    # Method: Check the result, moving up or down a step if needed.
    if not n:
        # Keep the return type consistent with the annotation: a Decimal
        # zero rather than a float.
        return Decimal('0.0')

    f_square = Fraction(n, m)

    d_mid = (Decimal(n) / Decimal(m)).sqrt()
    f_mid = Fraction(*d_mid.as_integer_ratio())

    # If the exact value exceeds the square of the midpoint between the
    # candidate and its upper neighbour, the upper neighbour is closer.
    d_plus = d_mid.next_plus()
    f_plus = Fraction(*d_plus.as_integer_ratio())
    if f_square > ((f_mid + f_plus) / 2) ** 2:
        return d_plus

    # Symmetric check against the lower neighbour.
    d_minus = d_mid.next_minus()
    f_minus = Fraction(*d_minus.as_integer_ratio())
    if f_square < ((f_mid + f_minus) / 2) ** 2:
        return d_minus

    return d_mid


# === Measures of central tendency (averages) ===

def mean(data):
Expand Down Expand Up @@ -888,9 +916,8 @@ def pstdev(data, mu=None):
raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu)
mss = ss / n
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt()
if issubclass(T, Decimal):
return _deci_sqrt(mss.numerator, mss.denominator)
return _sqrt_frac(mss.numerator, mss.denominator)


Expand Down
37 changes: 36 additions & 1 deletion Lib/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2219,7 +2219,42 @@ def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
statistics._sqrt_frac(1, 0)

# The result is well defined if both inputs are negative
self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
self.assertEqual(statistics._sqrt_frac(-2, -1), statistics._sqrt_frac(2, 1))

def test_deci_sqrt(self):
    # Annotations for the loop variables unpacked from the case table below.
    root: Decimal
    numerator: int
    denominator: int

    # Each case is (expected root, numerator, denominator).  The trailing
    # comment on each row presumably names which adjustment branch of
    # _deci_sqrt() the case is meant to exercise.
    for root, numerator, denominator in [
        (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj
        (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up
        (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down
    ]:
        # Evaluate the function under test with decimal.DefaultContext.
        with decimal.localcontext(decimal.DefaultContext):
            self.assertEqual(statistics._deci_sqrt(numerator, denominator), root)

            # Confirm expected root with a quad precision decimal computation
            with decimal.localcontext(decimal.DefaultContext) as ctx:
                ctx.prec *= 4
                high_prec_root = (Decimal(numerator) / Decimal(denominator)).sqrt()
        # Unary plus rounds the high precision root to the context in effect.
        with decimal.localcontext(decimal.DefaultContext):
            target_root = +high_prec_root
        self.assertEqual(root, target_root)

    # Verify that corner cases and error handling match Decimal.sqrt()
    self.assertEqual(statistics._deci_sqrt(0, 1), 0.0)
    with self.assertRaises(decimal.InvalidOperation):
        statistics._deci_sqrt(-1, 1)
    with self.assertRaises(decimal.InvalidOperation):
        statistics._deci_sqrt(1, -1)

    # Error handling for zero denominator matches that for Fraction(1, 0)
    with self.assertRaises(ZeroDivisionError):
        statistics._deci_sqrt(1, 0)

    # The result is well defined if both inputs are negative
    self.assertEqual(statistics._deci_sqrt(-2, -1), statistics._deci_sqrt(2, 1))


class TestStdev(VarianceStdevMixin, NumericTestCase):
Expand Down
0