|
4 | 4 | from test.support import verbose, requires_IEEE_754
|
5 | 5 | from test import support
|
6 | 6 | import unittest
|
| 7 | +import fractions |
7 | 8 | import itertools
|
8 | 9 | import decimal
|
9 | 10 | import math
|
@@ -1202,6 +1203,171 @@ def testLog10(self):
|
1202 | 1203 | self.assertEqual(math.log(INF), INF)
|
1203 | 1204 | self.assertTrue(math.isnan(math.log10(NAN)))
|
1204 | 1205 |
|
| 1206 | + def testSumProd(self): |
| 1207 | + sumprod = math.sumprod |
| 1208 | + Decimal = decimal.Decimal |
| 1209 | + Fraction = fractions.Fraction |
| 1210 | + |
| 1211 | + # Core functionality |
| 1212 | + self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140) |
| 1213 | + self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5) |
| 1214 | + self.assertEqual(sumprod([], []), 0) |
| 1215 | + |
| 1216 | + # Type preservation and coercion |
| 1217 | + for v in [ |
| 1218 | + (10, 20, 30), |
| 1219 | + (1.5, -2.5), |
| 1220 | + (Fraction(3, 5), Fraction(4, 5)), |
| 1221 | + (Decimal(3.5), Decimal(4.5)), |
| 1222 | + (2.5, 10), # float/int |
| 1223 | + (2.5, Fraction(3, 5)), # float/fraction |
| 1224 | + (25, Fraction(3, 5)), # int/fraction |
| 1225 | + (25, Decimal(4.5)), # int/decimal |
| 1226 | + ]: |
| 1227 | + for p, q in [(v, v), (v, v[::-1])]: |
| 1228 | + with self.subTest(p=p, q=q): |
| 1229 | + expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True)) |
| 1230 | + actual = sumprod(p, q) |
| 1231 | + self.assertEqual(expected, actual) |
| 1232 | + self.assertEqual(type(expected), type(actual)) |
| 1233 | + |
| 1234 | + # Bad arguments |
| 1235 | + self.assertRaises(TypeError, sumprod) # No args |
| 1236 | + self.assertRaises(TypeError, sumprod, []) # One arg |
| 1237 | + self.assertRaises(TypeError, sumprod, [], [], []) # Three args |
| 1238 | + self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable |
| 1239 | + self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable |
| 1240 | + |
| 1241 | + # Uneven lengths |
| 1242 | + self.assertRaises(ValueError, sumprod, [10, 20], [30]) |
| 1243 | + self.assertRaises(ValueError, sumprod, [10], [20, 30]) |
| 1244 | + |
| 1245 | + # Error in iterator |
| 1246 | + def raise_after(n): |
| 1247 | + for i in range(n): |
| 1248 | + yield i |
| 1249 | + raise RuntimeError |
| 1250 | + with self.assertRaises(RuntimeError): |
| 1251 | + sumprod(range(10), raise_after(5)) |
| 1252 | + with self.assertRaises(RuntimeError): |
| 1253 | + sumprod(raise_after(5), range(10)) |
| 1254 | + |
| 1255 | + # Error in multiplication |
| 1256 | + class BadMultiply: |
| 1257 | + def __mul__(self, other): |
| 1258 | + raise RuntimeError |
| 1259 | + def __rmul__(self, other): |
| 1260 | + raise RuntimeError |
| 1261 | + with self.assertRaises(RuntimeError): |
| 1262 | + sumprod([10, BadMultiply(), 30], [1, 2, 3]) |
| 1263 | + with self.assertRaises(RuntimeError): |
| 1264 | + sumprod([1, 2, 3], [10, BadMultiply(), 30]) |
| 1265 | + |
| 1266 | + # Error in addition |
| 1267 | + with self.assertRaises(TypeError): |
| 1268 | + sumprod(['abc', 3], [5, 10]) |
| 1269 | + with self.assertRaises(TypeError): |
| 1270 | + sumprod([5, 10], ['abc', 3]) |
| 1271 | + |
| 1272 | + # Special values should give the same as the pure python recipe |
| 1273 | + self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf) |
| 1274 | + self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf) |
| 1275 | + self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf) |
| 1276 | + self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf) |
| 1277 | + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf]))) |
| 1278 | + self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3]))) |
| 1279 | + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3]))) |
| 1280 | + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan]))) |
| 1281 | + |
| 1282 | + # Error cases that arose during development |
| 1283 | + args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952)) |
| 1284 | + self.assertEqual(sumprod(*args), 0.0) |
| 1285 | + |
| 1286 | + |
| 1287 | + @requires_IEEE_754 |
| 1288 | + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, |
| 1289 | + "sumprod() accuracy not guaranteed on machines with double rounding") |
| 1290 | + @support.cpython_only # Other implementations may choose a different algorithm |
| 1291 | + def test_sumprod_accuracy(self): |
| 1292 | + sumprod = math.sumprod |
| 1293 | + self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0) |
| 1294 | + self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0) |
| 1295 | + self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0) |
| 1296 | + |
| 1297 | + def test_sumprod_stress(self): |
| 1298 | + sumprod = math.sumprod |
| 1299 | + product = itertools.product |
| 1300 | + Decimal = decimal.Decimal |
| 1301 | + Fraction = fractions.Fraction |
| 1302 | + |
| 1303 | + class Int(int): |
| 1304 | + def __add__(self, other): |
| 1305 | + return Int(int(self) + int(other)) |
| 1306 | + def __mul__(self, other): |
| 1307 | + return Int(int(self) * int(other)) |
| 1308 | + __radd__ = __add__ |
| 1309 | + __rmul__ = __mul__ |
| 1310 | + def __repr__(self): |
| 1311 | + return f'Int({int(self)})' |
| 1312 | + |
| 1313 | + class Flt(float): |
| 1314 | + def __add__(self, other): |
| 1315 | + return Int(int(self) + int(other)) |
| 1316 | + def __mul__(self, other): |
| 1317 | + return Int(int(self) * int(other)) |
| 1318 | + __radd__ = __add__ |
| 1319 | + __rmul__ = __mul__ |
| 1320 | + def __repr__(self): |
| 1321 | + return f'Flt({int(self)})' |
| 1322 | + |
| 1323 | + def baseline_sumprod(p, q): |
| 1324 | + """This defines the target behavior including expections and special values. |
| 1325 | + However, it is subject to rounding errors, so float inputs should be exactly |
| 1326 | + representable with only a few bits. |
| 1327 | + """ |
| 1328 | + total = 0 |
| 1329 | + for p_i, q_i in zip(p, q, strict=True): |
| 1330 | + total += p_i * q_i |
| 1331 | + return total |
| 1332 | + |
| 1333 | + def run(func, *args): |
| 1334 | + "Make comparing functions easier. Returns error status, type, and result." |
| 1335 | + try: |
| 1336 | + result = func(*args) |
| 1337 | + except (AssertionError, NameError): |
| 1338 | + raise |
| 1339 | + except Exception as e: |
| 1340 | + return type(e), None, 'None' |
| 1341 | + return None, type(result), repr(result) |
| 1342 | + |
| 1343 | + pools = [ |
| 1344 | + (-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)), |
| 1345 | + (5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125), |
| 1346 | + (-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333, |
| 1347 | + 5.25, -3.25, -3.0*2**(-333), 3, 2**513), |
| 1348 | + (3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14, |
| 1349 | + 9, 3+4j, Flt(13), 0.0), |
| 1350 | + (13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8), |
| 1351 | + Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)), |
| 1352 | + (Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0), |
| 1353 | + Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5), |
| 1354 | + (-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538, |
| 1355 | + 2*2**-513), |
| 1356 | + (-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25), |
| 1357 | + (11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)), |
| 1358 | + ] |
| 1359 | + |
| 1360 | + for pool in pools: |
| 1361 | + for size in range(4): |
| 1362 | + for args1 in product(pool, repeat=size): |
| 1363 | + for args2 in product(pool, repeat=size): |
| 1364 | + args = (args1, args2) |
| 1365 | + self.assertEqual( |
| 1366 | + run(baseline_sumprod, *args), |
| 1367 | + run(sumprod, *args), |
| 1368 | + args, |
| 1369 | + ) |
| 1370 | + |
1205 | 1371 | def testModf(self):
|
1206 | 1372 | self.assertRaises(TypeError, math.modf)
|
1207 | 1373 |
|
|
0 commit comments