bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736) · python/cpython@af9ee57 · GitHub
[go: up one dir, main page]

Skip to content

Commit af9ee57

Browse files
bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)
* Inlined code from variance functions
* Added helper functions for the float square root of a fraction
* Call helper functions
* Add blurb
* Fix over-specified test
* Add a test for the _sqrt_frac() helper function
* Increase the tested range
* Add type hints to the internal function.
* Fix test for correct rounding
* Simplify ⌊√(n/m)⌋ calculation

Co-authored-by: Mark Dickinson <dickinsm@gmail.com>

* Add comment and beef-up tests
* Test for zero denominator
* Add algorithmic references
* Add test for the _isqrt_frac_rto() helper function.
* Compute the 109 instead of hard-wiring it
* Stronger test for _isqrt_frac_rto()
* Bigger range
* Bigger range
* Replace float() call with int/int division to be parallel with the other code path.
* Factor out division. Update proof link. Remove internal type declaration

Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
1 parent db55f3f commit af9ee57

File tree

3 files changed

+107
-16
lines changed

3 files changed

+107
-16
lines changed

Lib/statistics.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
import math
131131
import numbers
132132
import random
133+
import sys
133134

134135
from fractions import Fraction
135136
from decimal import Decimal
@@ -304,6 +305,27 @@ def _fail_neg(values, errmsg='negative value'):
304305
raise StatisticsError(errmsg)
305306
yield x
306307

308+
def _isqrt_frac_rto(n: int, m: int) -> float:
309+
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
310+
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
311+
a = math.isqrt(n // m)
312+
return a | (a*a*m != n)
313+
314+
# Number of bits of scaling used by _sqrt_frac(), derived from the float
# mantissa width.  For 53 bit precision floats this evaluates to 109.
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3


def _sqrt_frac(n: int, m: int) -> float:
    """Return the square root of n/m as a correctly rounded float.

    The fraction is rescaled by a power of four, its integer square root
    is taken with round-to-odd, and a single int/int division produces
    the float result.
    """
    # See principle and proof sketch at: https://bugs.python.org/msg407078
    half_shift = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2

    if half_shift < 0:
        # Too few bits in n/m: scale the numerator up by 4**(-half_shift)
        # and divide the root back down by 2**(-half_shift).
        scaled_root = _isqrt_frac_rto(n << -2 * half_shift, m)
        return scaled_root / (1 << -half_shift)  # Convert to float

    # Enough bits already: scale the denominator up by 4**half_shift and
    # shift the root back up by half_shift bits.
    scaled_root = _isqrt_frac_rto(n, m << 2 * half_shift) << half_shift
    return scaled_root / 1  # Convert to float
328+
307329

308330
# === Measures of central tendency (averages) ===
309331

@@ -837,14 +859,17 @@ def stdev(data, xbar=None):
837859
1.0810874155219827
838860
839861
"""
840-
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
841-
# remain because there are two rounding steps. The first occurs in
842-
# the _convert() step for variance(), the second occurs in math.sqrt().
843-
var = variance(data, xbar)
844-
try:
862+
if iter(data) is data:
863+
data = list(data)
864+
n = len(data)
865+
if n < 2:
866+
raise StatisticsError('stdev requires at least two data points')
867+
T, ss = _ss(data, xbar)
868+
mss = ss / (n - 1)
869+
if hasattr(T, 'sqrt'):
870+
var = _convert(mss, T)
845871
return var.sqrt()
846-
except AttributeError:
847-
return math.sqrt(var)
872+
return _sqrt_frac(mss.numerator, mss.denominator)
848873

849874

850875
def pstdev(data, mu=None):
@@ -856,14 +881,17 @@ def pstdev(data, mu=None):
856881
0.986893273527251
857882
858883
"""
859-
# Fixme: Despite the exact sum of squared deviations, some inaccuracy
860-
# remain because there are two rounding steps. The first occurs in
861-
# the _convert() step for pvariance(), the second occurs in math.sqrt().
862-
var = pvariance(data, mu)
863-
try:
884+
if iter(data) is data:
885+
data = list(data)
886+
n = len(data)
887+
if n < 1:
888+
raise StatisticsError('pstdev requires at least one data point')
889+
T, ss = _ss(data, mu)
890+
mss = ss / n
891+
if hasattr(T, 'sqrt'):
892+
var = _convert(mss, T)
864893
return var.sqrt()
865-
except AttributeError:
866-
return math.sqrt(var)
894+
return _sqrt_frac(mss.numerator, mss.denominator)
867895

868896

869897
# === Statistics for relations between two inputs ===

Lib/test/test_statistics.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
import copy
1010
import decimal
1111
import doctest
12+
import itertools
1213
import math
1314
import pickle
1415
import random
1516
import sys
1617
import unittest
1718
from test import support
18-
from test.support import import_helper
19+
from test.support import import_helper, requires_IEEE_754
1920

2021
from decimal import Decimal
2122
from fractions import Fraction
@@ -2161,6 +2162,66 @@ def test_center_not_at_mean(self):
21612162
self.assertEqual(self.func(data), 2.5)
21622163
self.assertEqual(self.func(data, mu=0.5), 6.5)
21632164

2165+
class TestSqrtHelpers(unittest.TestCase):
    """Tests for the private square-root helpers in the statistics module."""

    def test_isqrt_frac_rto(self):
        """_isqrt_frac_rto() returns an int that is exact or odd."""
        for n, m in itertools.product(range(100), range(1, 1000)):
            r = statistics._isqrt_frac_rto(n, m)
            self.assertIsInstance(r, int)
            if r*r*m == n:
                # Root is exact
                continue
            # Inexact, so the root should be odd
            self.assertEqual(r&1, 1)
            # Verify correct rounding
            self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)

    @requires_IEEE_754
    def test_sqrt_frac(self):
        """_sqrt_frac() is correctly rounded and raises like math.sqrt()."""

        def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
            if not x:
                return root == 0.0

            # Extract adjacent representable floats
            r_up: float = math.nextafter(root, math.inf)
            r_down: float = math.nextafter(root, -math.inf)
            assert r_down < root < r_up

            # Convert to fractions for exact arithmetic
            frac_root: Fraction = Fraction(root)
            half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
            half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2

            # Check a closed interval.
            # Does not test for a midpoint rounding rule.
            return half_way_down ** 2 <= x <= half_way_up ** 2

        randrange = random.randrange

        for _ in range(60_000):
            numerator: int = randrange(10 ** randrange(50))
            denominator: int = randrange(10 ** randrange(50)) + 1
            with self.subTest(numerator=numerator, denominator=denominator):
                x: Fraction = Fraction(numerator, denominator)
                root: float = statistics._sqrt_frac(numerator, denominator)
                self.assertTrue(is_root_correctly_rounded(x, root))

        # Verify that corner cases and error handling match math.sqrt()
        self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
        with self.assertRaises(ValueError):
            statistics._sqrt_frac(-1, 1)
        with self.assertRaises(ValueError):
            statistics._sqrt_frac(1, -1)

        # Error handling for zero denominator matches that for Fraction(1, 0)
        with self.assertRaises(ZeroDivisionError):
            statistics._sqrt_frac(1, 0)

        # The result is well defined if both inputs are negative
        self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
2224+
21642225
class TestStdev(VarianceStdevMixin, NumericTestCase):
21652226
# Tests for sample standard deviation.
21662227
def setUp(self):
@@ -2175,7 +2236,7 @@ def test_compare_to_variance(self):
21752236
# Test that stdev is, in fact, the square root of variance.
21762237
data = [random.uniform(-2, 9) for _ in range(1000)]
21772238
expected = math.sqrt(statistics.variance(data))
2178-
self.assertEqual(self.func(data), expected)
2239+
self.assertAlmostEqual(self.func(data), expected)
21792240

21802241
def test_center_not_at_mean(self):
21812242
data = (1.0, 2.0)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Improve the accuracy of stdev() and pstdev() in the statistics module. When
2+
the inputs are floats or fractions, the output is a correctly rounded float.

0 commit comments

Comments
 (0)
0