33
33
from sklearn .base import BaseEstimator
34
34
from sklearn .base import clone
35
35
from sklearn .exceptions import NotFittedError
36
+ from sklearn .exceptions import ConvergenceWarning
36
37
from sklearn .datasets import make_classification
37
38
from sklearn .datasets import make_blobs
38
39
from sklearn .datasets import make_multilabel_classification
@@ -344,11 +345,11 @@ def test_return_train_score_warn():
344
345
345
346
estimators = [
346
347
GridSearchCV (
347
- LinearSVC (random_state = 0 , max_iter = 100000 ),
348
+ LinearSVC (random_state = 0 ),
348
349
grid ,
349
350
iid = False ),
350
351
RandomizedSearchCV (
351
- LinearSVC (random_state = 0 , max_iter = 100000 ),
352
+ LinearSVC (random_state = 0 ),
352
353
grid ,
353
354
n_iter = 2 ,
354
355
iid = False )
@@ -358,7 +359,11 @@ def test_return_train_score_warn():
358
359
for estimator in estimators :
359
360
for val in [True , False , 'warn' ]:
360
361
estimator .set_params (return_train_score = val )
361
- result [val ] = assert_no_warnings (estimator .fit , X , y ).cv_results_
362
+ result [val ] = assert_no_warnings (
363
+ ignore_warnings (estimator .fit , category = ConvergenceWarning ),
364
+ X ,
365
+ y
366
+ ).cv_results_
362
367
363
368
train_keys = ['split0_train_score' , 'split1_train_score' ,
364
369
'split2_train_score' , 'mean_train_score' , 'std_train_score' ]
0 commit comments