8000 Replace ConvergenceWarning by RuntimeWarning (#7922) · maskani-moh/scikit-learn@a400838 · GitHub
[go: up one dir, main page]

Skip to content

Commit a400838

Browse files
lestevemaskani-moh
authored andcommitted
Replace ConvergenceWarning by RuntimeWarning (scikit-learn#7922)
when cumsum is unstable.
1 parent 046c364 commit a400838

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

sklearn/utils/extmath.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..externals.six.moves import xrange
2626
from .sparsefuncs_fast import csr_row_norms
2727
from .validation import check_array
28-
from ..exceptions import ConvergenceWarning, NonBLASDotWarning
28+
from ..exceptions import NonBLASDotWarning
2929

3030

3131
def norm(x):
@@ -869,5 +869,5 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
869869
atol=atol, equal_nan=True)):
870870
warnings.warn('cumsum was found to be unstable: '
871871
'its last element does not correspond to sum',
872-
ConvergenceWarning)
872+
RuntimeWarning)
873873
return out

sklearn/utils/tests/test_extmath.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from sklearn.utils.extmath import _deterministic_vector_sign_flip
3737
from sklearn.utils.extmath import softmax
3838
from sklearn.utils.extmath import stable_cumsum
39-
from sklearn.exceptions import ConvergenceWarning
4039
from sklearn.datasets.samples_generator import make_low_rank_matrix
4140

4241

@@ -655,7 +654,7 @@ def test_stable_cumsum():
655654
raise SkipTest("Sum is as unstable as cumsum for numpy < 1.9")
656655
assert_array_equal(stable_cumsum([1, 2, 3]), np.cumsum([1, 2, 3]))
657656
r = np.random.RandomState(0).rand(100000)
658-
assert_warns(ConvergenceWarning, stable_cumsum, r, rtol=0, atol=0)
657+
assert_warns(RuntimeWarning, stable_cumsum, r, rtol=0, atol=0)
659658

660659
# test axis parameter
661660
A = np.random.RandomState(36).randint(1000, size=(5, 5, 5))

0 commit comments

Comments
 (0)
0