8000 TST replace pytest.warns(None) in test_label_propagation.py (#23010) · scikit-learn/scikit-learn@408afcc · GitHub
[go: up one dir, main page]

Skip to content

Commit 408afcc

Browse files
Ben3940jeremiedbbthomasjpfan
authored
TST replace pytest.warns(None) in test_label_propagation.py (#23010)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent b4da3b4 commit 408afcc

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

sklearn/semi_supervised/tests/test_label_propagation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pytest
5+
import warnings
56

67
from scipy.sparse import issparse
78
from sklearn.semi_supervised import _label_propagation as label_propagation
@@ -151,14 +152,14 @@ def test_convergence_warning():
151152
assert mdl.n_iter_ == mdl.max_iter
152153

153154
mdl = label_propagation.LabelSpreading(kernel="rbf", max_iter=500)
154-
with pytest.warns(None) as record:
155+
with warnings.catch_warnings():
156+
warnings.simplefilter("error", ConvergenceWarning)
155157
mdl.fit(X, y)
156-
assert not [w.message for w in record]
157158

158159
mdl = label_propagation.LabelPropagation(kernel="rbf", max_iter=500)
159-
with pytest.warns(None) as record:
160+
with warnings.catch_warnings():
161+
warnings.simplefilter("error", ConvergenceWarning)
160162
mdl.fit(X, y)
161-
assert not [w.message for w in record]
162163

163164

164165
@pytest.mark.parametrize(
@@ -173,9 +174,9 @@ def test_label_propagation_non_zero_normalizer(LabelPropagationCls):
173174
X = np.array([[100.0, 100.0], [100.0, 100.0], [0.0, 0.0], [0.0, 0.0]])
174175
y = np.array([0, 1, -1, -1])
175176
mdl = LabelPropagationCls(kernel="knn", max_iter=100, n_neighbors=1)
176-
with pytest.warns(None) as record:
177+
with warnings.catch_warnings():
178+
warnings.simplefilter("error", RuntimeWarning)
177179
mdl.fit(X, y)
178-
assert not [w.message for w in record]
179180

180181

181182
def test_predict_sparse_callable_kernel():

0 commit comments

Comments
 (0)
0