8000 GH-100485: Add math.sumprod() (GH-100677) · python/cpython@47b9f83 · GitHub
[go: up one dir, main page]

Skip to content

Commit 47b9f83

Browse files
authored
GH-100485: Add math.sumprod() (GH-100677)
1 parent deaf090 commit 47b9f83

File tree

6 files changed

+548
-10
lines changed

6 files changed

+548
-10
lines changed

Doc/library/itertools.rst

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ by combining :func:`map` and :func:`count` to form ``map(f, count())``.
3333
These tools and their built-in counterparts also work well with the high-speed
3434
functions in the :mod:`operator` module. For example, the multiplication
3535
operator can be mapped across two vectors to form an efficient dot-product:
36-
``sum(map(operator.mul, vector1, vector2))``.
36+
``sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))``.
3737

3838

3939
**Infinite iterators:**
@@ -838,10 +838,6 @@ which incur interpreter overhead.
838838
"Returns the sequence elements n times"
839839
return chain.from_iterable(repeat(tuple(iterable), n))
840840

841-
def dotproduct(vec1, vec2):
842-
"Compute a sum of products."
843-
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))
844-
845841
def convolve(signal, kernel):
846842
# See: https://betterexplained.com/articles/intuitive-convolution/
847843
# convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
@@ -852,7 +848,7 @@ which incur interpreter overhead.
852848
window = collections.deque([0], maxlen=n) * n
853849
for x in chain(signal, repeat(0, n-1)):
854850
window.append(x)
855-
yield dotproduct(kernel, window)
851+
yield math.sumprod(kernel, window)
856852

857853
def polynomial_from_roots(roots):
858854
"""Compute a polynomial's coefficients from its roots.
@@ -1211,9 +1207,6 @@ which incur interpreter overhead.
12111207
>>> list(ncycles('abc', 3))
12121208
['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c']
12131209

1214-
>>> dotproduct([1,2,3], [4,5,6])
1215-
32
1216-
12171210
>>> data = [20, 40, 24, 32, 20, 28, 16]
12181211
>>> list(convolve(data, [0.25, 0.25, 0.25, 0.25]))
12191212
[5.0, 15.0, 21.0, 29.0, 29.0, 26.0, 24.0, 16.0, 11.0, 4.0]

Doc/library/math.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,22 @@ Number-theoretic and representation functions
291291
.. versionadded:: 3.7
292292

293293

294+
.. function:: sumprod(p, q)
295+
296+
Return the sum of products of values from two iterables *p* and *q*.
297+
298+
Raises :exc:`ValueError` if the inputs do not have the same length.
299+
300+
Roughly equivalent to::
301+
302+
sum(itertools.starmap(operator.mul, zip(p, q, strict=true)))
303+
304< 6D40 /td>+
For float and mixed int/float inputs, the intermediate products
305+
and sums are computed with extended precision.
306+
307+
.. versionadded:: 3.12
308+
309+
294310
.. function:: trunc(x)
295311

296312
Return *x* with the fractional part

Lib/test/test_math.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from test.support import verbose, requires_IEEE_754
55
from test import support
66
import unittest
7+
import fractions
78
import itertools
89
import decimal
910
import math
@@ -1202,6 +1203,171 @@ def testLog10(self):
12021203
self.assertEqual(math.log(INF), INF)
12031204
self.assertTrue(math.isnan(math.log10(NAN)))
12041205

1206+
def testSumProd(self):
1207+
sumprod = math.sumprod
1208+
Decimal = decimal.Decimal
1209+
Fraction = fractions.Fraction
1210+
1211+
# Core functionality
1212+
self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140)
1213+
self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5)
1214+
self.assertEqual(sumprod([], []), 0)
1215+
1216+
# Type preservation and coercion
1217+
for v in [
1218+
(10, 20, 30),
1219+
(1.5, -2.5),
1220+
(Fraction(3, 5), Fraction(4, 5)),
1221+
(Decimal(3.5), Decimal(4.5)),
1222+
(2.5, 10), # float/int
1223+
(2.5, Fraction(3, 5)), # float/fraction
1224+
(25, Fraction(3, 5)), # int/fraction
1225+
(25, Decimal(4.5)), # int/decimal
1226+
]:
1227+
for p, q in [(v, v), (v, v[::-1])]:
1228+
with self.subTest(p=p, q=q):
1229+
expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True))
1230+
actual = sumprod(p, q)
1231+
self.assertEqual(expected, actual)
1232+
self.assertEqual(type(expected), type(actual))
1233+
1234+
# Bad arguments
1235+
self.assertRaises(TypeError, sumprod) # No args
1236+
self.assertRaises(TypeError, sumprod, []) # One arg
1237+
self.assertRaises(TypeError, sumprod, [], [], []) # Three args
1238+
self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable
1239+
self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable
1240+
1241+
# Uneven lengths
1242+
self.assertRaises(ValueError, sumprod, [10, 20], [30])
1243+
self.assertRaises(ValueError, sumprod, [10], [20, 30])
1244+
1245+
# Error in iterator
1246+
def raise_after(n):
1247+
for i in range(n):
1248+
yield i
1249+
raise RuntimeError
1250+
with self.assertRaises(RuntimeError):
1251+
sumprod(range(10), raise_after(5))
1252+
with self.assertRaises(RuntimeError):
1253+
sumprod(raise_after(5), range(10))
1254+
1255+
# Error in multiplication
1256+
class BadMultiply:
1257+
def __mul__(self, other):
1258+
raise RuntimeError
1259+
def __rmul__(self, other):
1260+
raise RuntimeError
1261+
with self.assertRaises(RuntimeError):
1262+
sumprod([10, BadMultiply(), 30], [1, 2, 3])
1263+
with self.assertRaises(RuntimeError):
1264+
sumprod([1, 2, 3], [10, BadMultiply(), 30])
1265+
1266+
# Error in addition
1267+
with self.assertRaises(TypeError):
1268+
sumprod(['abc', 3], [5, 10])
1269+
with self.assertRaises(TypeError):
1270+
sumprod([5, 10], ['abc', 3])
1271+
1272+
# Special values should give the same as the pure python recipe
1273+
self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf)
1274+
self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf)
1275+
self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf)
1276+
self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf)
1277+
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf])))
1278+
self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3])))
1279+
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3])))
1280+
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan])))
1281+
1282+
# Error cases that arose during development
1283+
args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952))
1284+
self.assertEqual(sumprod(*args), 0.0)
1285+
1286+
1287+
@requires_IEEE_754
1288+
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
1289+
"sumprod() accuracy not guaranteed on machines with double rounding")
1290+
@support.cpython_only # Other implementations may choose a different algorithm
1291+
def test_sumprod_accuracy(self):
1292+
sumprod = math.sumprod
1293+
self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0)
1294+
self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0)
1295+
self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0)
1296+
1297+
def test_sumprod_stress(self):
1298+
sumprod = math.sumprod
1299+
product = itertools.product
1300+
Decimal = decimal.Decimal
1301+
Fraction = fractions.Fraction
1302+
1303+
class Int(int):
1304+
def __add__(self, other):
1305+
return Int(int(self) + int(other))
1306+
def __mul__(self, other):
1307+
return Int(int(self) * int(other))
1308+
__radd__ = __add__
1309+
__rmul__ = __mul__
1310+
def __repr__(self):
1311+
return f'Int({int(self)})'
1312+
1313+
class Flt(float):
1314+
def __add__(self, other):
1315+
return Int(int(self) + int(other))
1316+
def __mul__(self, other):
1317+
return Int(int(self) * int(other))
1318+
__radd__ = __add__
1319+
__rmul__ = __mul__
1320+
def __repr__(self):
1321+
return f'Flt({int(self)})'
1322+
1323+
def baseline_sumprod(p, q):
1324+
"""This defines the target behavior including expections and special values.
1325+
However, it is subject to rounding errors, so float inputs should be exactly
1326+
representable with only a few bits.
1327+
"""
1328+
total = 0
1329+
for p_i, q_i in zip(p, q, strict=True):
1330+
total += p_i * q_i
1331+
return total
1332+
1333+
def run(func, *args):
1334+
"Make comparing functions easier. Returns error status, type, and result."
1335+
try:
1336+
result = func(*args)
1337+
except (AssertionError, NameError):
1338+
raise
1339+
except Exception as e:
1340+
return type(e), None, 'None'
1341+
return None, type(result), repr(result)
1342+
1343+
pools = [
1344+
(-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)),
1345+
(5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125),
1346+
(-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333,
1347+
5.25, -3.25, -3.0*2**(-333), 3, 2**513),
1348+
(3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14,
1349+
9, 3+4j, Flt(13), 0.0),
1350+
(13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8),
1351+
Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)),
1352+
(Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0),
1353+
Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5),
1354+
(-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538,
1355+
2*2**-513),
1356+
(-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25),
1357+
(11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)),
1358+
]
1359+
1360+
for pool in pools:
1361+
for size in range(4):
1362+
for args1 in product(pool, repeat=size):
1363+
for args2 in product(pool, repeat=size):
1364+
args = (args1, args2)
1365+
self.assertEqual(
1366+
run(baseline_sumprod, *args),
1367+
run(sumprod, *args),
1368+
args,
1369+
)
1370+
12051371
def testModf(self):
12061372
self.assertRaises(TypeError, math.modf)
12071373

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add math.sumprod() to compute the sum of products.

Modules/clinic/mathmodule.c.h

Lines changed: 38 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
0