8000 Ignore convergence warning in test_return_train_score_warn. · AlexandreSev/scikit-learn@3eb22a0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3eb22a0

Browse files
committed
Ignore convergence warning in test_return_train_score_warn.
See scikit-learn#10866.
1 parent cfee9c1 commit 3eb22a0

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

sklearn/model_selection/tests/test_search.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from sklearn.base import BaseEstimator
3434
from sklearn.base import clone
3535
from sklearn.exceptions import NotFittedError
36+
from sklearn.exceptions import ConvergenceWarning
3637
from sklearn.datasets import make_classification
3738
from sklearn.datasets import make_blobs
3839
from sklearn.datasets import make_multilabel_classification
@@ -344,11 +345,11 @@ def test_return_train_score_warn():
344345

345346
estimators = [
346347
GridSearchCV(
347-
LinearSVC(random_state=0, max_iter=100000),
348+
LinearSVC(random_state=0),
348349
grid,
349350
iid=False),
350351
RandomizedSearchCV(
351-
LinearSVC(random_state=0, max_iter=100000),
352+
LinearSVC(random_state=0),
352353
grid,
353354
n_iter=2,
354355
iid=False)
@@ -358,7 +359,11 @@ def test_return_train_score_warn():
358359
for estimator in estimators:
359360
for val in [True, False, 'warn']:
360361
estimator.set_params(return_train_score=val)
361-
result[val] = assert_no_warnings(estimator.fit, X, y).cv_results_
362+
result[val] = assert_no_warnings(
363+
ignore_warnings(estimator.fit, category=ConvergenceWarning),
364+
X,
365+
y
366+
).cv_results_
362367

363368
train_keys = ['split0_train_score', 'split1_train_score',
364369
'split2_train_score', 'mean_train_score', 'std_train_score']

0 commit comments

Comments
 (0)
0