2
2
3
3
import numpy as np
4
4
import pytest
5
+ import warnings
5
6
6
7
from scipy .sparse import issparse
7
8
from sklearn .semi_supervised import _label_propagation as label_propagation
@@ -151,14 +152,14 @@ def test_convergence_warning():
151
152
assert mdl .n_iter_ == mdl .max_iter
152
153
153
154
mdl = label_propagation .LabelSpreading (kernel = "rbf" , max_iter = 500 )
154
- with pytest .warns (None ) as record :
155
+ with warnings .catch_warnings ():
156
+ warnings .simplefilter ("error" , ConvergenceWarning )
155
157
mdl .fit (X , y )
156
- assert not [w .message for w in record ]
157
158
158
159
mdl = label_propagation .LabelPropagation (kernel = "rbf" , max_iter = 500 )
159
- with pytest .warns (None ) as record :
160
+ with warnings .catch_warnings ():
161
+ warnings .simplefilter ("error" , ConvergenceWarning )
160
162
mdl .fit (X , y )
161
- assert not [w .message for w in record ]
162
163
163
164
164
165
@pytest .mark .parametrize (
@@ -173,9 +174,9 @@ def test_label_propagation_non_zero_normalizer(LabelPropagationCls):
173
174
X = np .array ([[100.0 , 100.0 ], [100.0 , 100.0 ], [0.0 , 0.0 ], [0.0 , 0.0 ]])
174
175
y = np .array ([0 , 1 , - 1 , - 1 ])
175
176
mdl = LabelPropagationCls (kernel = "knn" , max_iter = 100 , n_neighbors = 1 )
176
- with pytest .warns (None ) as record :
177
+ with warnings .catch_warnings ():
178
+ warnings .simplefilter ("error" , RuntimeWarning )
177
179
mdl .fit (X , y )
178
- assert not [w .message for w in record ]
179
180
180
181
181
182
def test_predict_sparse_callable_kernel ():
0 commit comments