[MRG + 1] FIX use high precision cumsum and check it is stable enough · scikit-learn/scikit-learn@49d126f
Commit 49d126f

jnothman authored and lesteve committed

[MRG + 1] FIX use high precision cumsum and check it is stable enough (#7331)

* FIX use high precision cumsum and check it is stable enough

1 parent 2bf96df · commit 49d126f

File tree: 3 files changed, +38 −2 lines

sklearn/metrics/ranking.py

Lines changed: 3 additions & 2 deletions

@@ -27,6 +27,7 @@
 from ..utils import check_consistent_length
 from ..utils import column_or_1d, check_array
 from ..utils.multiclass import type_of_target
+from ..utils.extmath import stable_cumsum
 from ..utils.fixes import isclose
 from ..utils.fixes import bincount
 from ..utils.fixes import array_equal
@@ -337,9 +338,9 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

     # accumulate the true positives with decreasing threshold
-    tps = (y_true * weight).cumsum()[threshold_idxs]
+    tps = stable_cumsum(y_true * weight)[threshold_idxs]
     if sample_weight is not None:
-        fps = weight.cumsum()[threshold_idxs] - tps
+        fps = stable_cumsum(weight)[threshold_idxs] - tps
     else:
         fps = 1 + threshold_idxs - tps
     return fps, tps, y_score[threshold_idxs]
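
Note: the failure mode this change guards against is sketched minimally below (the snippet is not part of the commit). A plain cumsum over many low-precision sample weights drifts away from the true total, which can corrupt the tps/fps counts that ROC and precision-recall curves are built from; stable_cumsum forces the accumulation into float64.

import numpy as np

# Hypothetical input: one million float32 weights of 0.1.
weight = np.full(1000000, 0.1, dtype=np.float32)

naive = np.cumsum(weight)[-1]                     # accumulated in float32
stable = np.cumsum(weight, dtype=np.float64)[-1]  # what stable_cumsum does

print(naive)   # drifts visibly away from 100000
print(stable)  # ~100000, up to rounding of the float32 inputs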

sklearn/utils/extmath.py

Lines changed: 20 additions & 0 deletions

@@ -842,3 +842,23 @@ def _deterministic_vector_sign_flip(u):
     signs = np.sign(u[range(u.shape[0]), max_abs_rows])
     u *= signs[:, np.newaxis]
     return u
+
+
+def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
+    """Use high precision for cumsum and check that final value matches sum
+
+    Parameters
+    ----------
+    arr : array-like
+        To be cumulatively summed as flat
+    rtol : float
+        Relative tolerance, see ``np.allclose``
+    atol : float
+        Absolute tolerance, see ``np.allclose``
+    """
+    out = np.cumsum(arr, dtype=np.float64)
+    expected = np.sum(arr, dtype=np.float64)
+    if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
+        raise RuntimeError('cumsum was found to be unstable: '
+                           'its last element does not correspond to sum')
+    return out
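
A quick usage sketch for the new helper (assuming this revision of scikit-learn is installed): it behaves like np.cumsum but accumulates in float64 and raises if the final element disagrees with np.sum.

import numpy as np
from sklearn.utils.extmath import stable_cumsum

out = stable_cumsum(np.ones(5, dtype=np.float32))
print(out)  # [1. 2. 3. 4. 5.], promoted to float64

# With zero tolerances, even benign rounding differences between a running
# cumsum and np.sum trip the check:
r = np.random.RandomState(0).rand(100000)
try:
    stable_cumsum(r, rtol=0, atol=0)
except RuntimeError as exc:
    print(exc)  # cumsum was found to be unstable: ...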

sklearn/utils/tests/test_extmath.py

Lines changed: 15 additions & 0 deletions

@@ -17,7 +17,10 @@
 from sklearn.utils.testing import assert_false
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_raises
+from sklearn.utils.testing import assert_raise_message
 from sklearn.utils.testing import skip_if_32bit
+from sklearn.utils.testing import SkipTest
+from sklearn.utils.fixes import np_version

 from sklearn.utils.extmath import density
 from sklearn.utils.extmath import logsumexp
@@ -32,6 +35,7 @@
 from sklearn.utils.extmath import _incremental_mean_and_var
 from sklearn.utils.extmath import _deterministic_vector_sign_flip
 from sklearn.utils.extmath import softmax
+from sklearn.utils.extmath import stable_cumsum
 from sklearn.datasets.samples_generator import make_low_rank_matrix


@@ -643,3 +647,14 @@ def test_softmax():
     exp_X = np.exp(X)
     sum_exp_X = np.sum(exp_X, axis=1).reshape((-1, 1))
     assert_array_almost_equal(softmax(X), exp_X / sum_exp_X)
+
+
+def test_stable_cumsum():
+    if np_version < (1, 9):
+        raise SkipTest("Sum is as unstable as cumsum for numpy < 1.9")
+    assert_array_equal(stable_cumsum([1, 2, 3]), np.cumsum([1, 2, 3]))
+    r = np.random.RandomState(0).rand(100000)
+    assert_raise_message(RuntimeError,
+                         'cumsum was found to be unstable: its last element '
+                         'does not correspond to sum',
+                         stable_cumsum, r, rtol=0, atol=0)
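
The version guard reflects that np.sum gained pairwise summation in numpy 1.9 while np.cumsum kept a naive running total; on older releases the two are equally inaccurate, so the comparison cannot detect drift. A rough illustration of the gap on numpy >= 1.9 (indicative only, not from the commit):

import numpy as np

r = np.random.RandomState(0).rand(100000)
# Naive running total vs pairwise summation: typically a few ulps apart
# on long inputs, which is exactly what rtol=0, atol=0 exposes above.
print(np.cumsum(r, dtype=np.float64)[-1] - np.sum(r, dtype=np.float64))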
