8000 TST replace assert_warns* by pytest.warns in model_selection/tests (#… · scikit-learn/scikit-learn@b5e55f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit b5e55f7

Browse files
authored
TST replace assert_warns* by pytest.warns in model_selection/tests (#19458)
1 parent 43241b1 commit b5e55f7

File tree

4 files changed

+43
-25
lines changed

4 files changed

+43
-25
lines changed

sklearn/model_selection/_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1473,7 +1473,7 @@ def _translate_train_sizes(train_sizes, n_max_training_samples):
14731473
if n_ticks > train_sizes_abs.shape[0]:
14741474
warnings.warn("Removed duplicate entries from 'train_sizes'. Number "
14751475
"of ticks will be less than the size of "
1476-
"'train_sizes' %d instead of %d)."
1476+
"'train_sizes': %d instead of %d."
14771477
% (train_sizes_abs.shape[0], n_ticks), RuntimeWarning)
14781478

14791479
return train_sizes_abs

sklearn/model_selection/tests/test_search.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import pytest
1515

1616
from sklearn.utils._testing import (
17-
assert_warns,
18-
assert_warns_message,
1917
assert_raise_message,
2018
assert_array_equal,
2119
assert_array_almost_equal,
@@ -1433,7 +1431,12 @@ def test_grid_search_failing_classifier():
14331431
# error in this test.
14341432
gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
14351433
refit=False, error_score=0.0)
1436-
assert_warns(FitFailedWarning, gs.fit, X, y)
1434+
warning_message = (
1435+
"Estimator fit failed. The score on this train-test partition "
1436+
"for these parameters will be set to 0.0.*."
1437+
)
1438+
with pytest.warns(FitFailedWarning, match=warning_message):
1439+
gs.fit(X, y)
14371440
n_candidates = l 8000 en(gs.cv_results_['params'])
14381441

14391442
# Ensure that grid scores were set to zero as required for those fits
@@ -1449,7 +1452,12 @@ def get_cand_scores(i):
14491452

14501453
gs = GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
14511454
refit=False, error_score=float('nan'))
1452-
assert_warns(FitFailedWarning, gs.fit, X, y)
1455+
warning_message = (
1456+
"Estimator fit failed. The score on this train-test partition "
1457+
"for these parameters will be set to nan."
1458+
)
1459+
with pytest.warns(FitFailedWarning, match=warning_message):
1460+
gs.fit(X, y)
14531461
n_candidates = len(gs.cv_results_['params'])
14541462
assert all(np.all(np.isnan(get_cand_scores(cand_i)))
14551463
for cand_i in range(n_candidates)
@@ -1492,8 +1500,8 @@ def test_parameters_sampler_replacement():
14921500
'than n_iter=%d. Running %d iterations. For '
< 8000 div aria-hidden="true" style="left:-2px" class="position-absolute top-0 d-flex user-select-none DiffLineTableCellParts-module__in-progress-comment-indicator--hx3m3">
14931501
'exhaustive searches, use GridSearchCV.'
14941502
% (grid_size, n_iter, grid_size))
1495-
assert_warns_message(UserWarning, expected_warning,
1496-
list, sampler)
1503+
with pytest.warns(UserWarning, match=expected_warning):
1504+
list(sampler)
14971505

14981506
# degenerates to GridSearchCV if n_iter the same as grid_size
14991507
sampler = ParameterSampler(params, n_iter=8)

sklearn/model_selection/tests/test_split.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from sklearn.utils._testing import assert_raises_regexp
1515
from sklearn.utils._testing import assert_array_almost_equal
1616
from sklearn.utils._testing import assert_array_equal
17-
from sklearn.utils._testing import assert_warns_message
1817
from sklearn.utils._testing import assert_raise_message
1918
from sklearn.utils._testing import ignore_warnings
2019
from sklearn.utils.validation import _num_samples
@@ -193,8 +192,8 @@ def test_kfold_valueerrors():
193192
y = np.array([3, 3, -1, -1, 3])
194193

195194
skf_3 = StratifiedKFold(3)
196-
assert_warns_message(Warning, "The least populated class",
197-
next, skf_3.split(X2, y))
195+
with pytest.warns(Warning, match="The least populated class"):
196+
next(skf_3.split(X2, y))
198197

199198
# Check that despite the warning the folds are still computed even
200199
# though all the classes are not necessarily represented at on each

sklearn/model_selection/tests/test_validation.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from sklearn.utils._testing import assert_almost_equal
1818
from sklearn.utils._testing import assert_raises
1919
from sklearn.utils._testing import assert_raise_message
20-
from sklearn.utils._testing import assert_warns
21-
from sklearn.utils._testing import assert_warns_message
2220
from sklearn.utils._testing import assert_raises_regex
2321
from sklearn.utils._testing import assert_array_almost_equal
2422
from sklearn.utils._testing import assert_array_equal
@@ -857,13 +855,12 @@ def split(self, X, y=None, groups=None):
857855

858856
X, y = load_iris(return_X_y=True)
859857

860-
warning_message = ('Number of classes in training fold (2) does '
861-
'not match total number of classes (3). '
858+
warning_message = (r'Number of classes in training fold \(2\) does '
859+
r'not match total number of classes \(3\). '
862860
'Results may not be appropriate for your use case.')
863-
assert_warns_message(RuntimeWarning, warning_message,
864-
cross_val_predict,
865-
LogisticRegression(solver="liblinear"),
866-
X, y, method='predict_proba', cv=KFold(2))
861+
with pytest.warns(RuntimeWarning, match=warning_message):
862+
cross_val_predict(LogisticRegression(solver="liblinear"),
863+
X, y, method='predict_proba', cv=KFold(2))
867864

868865

869866
def test_cross_val_predict_decision_function_shape():
@@ -1210,9 +1207,13 @@ def test_learning_curve_remove_duplicate_sample_sizes():
12101207
n_redundant=0, n_classes=2,
12111208
n_clusters_per_class=1, random_state=0)
12121209
estimator = MockImprovingEstimator(2)
1213-
train_sizes, _, _ = assert_warns(
1214-
RuntimeWarning, learning_curve, estimator, X, y, cv=3,
1215-
train_sizes=np.linspace(0.33, 1.0, 3))
1210+
warning_message = (
1211+
"Removed duplicate entries from 'train_sizes'. Number of ticks "
1212+
"will be less than the size of 'train_sizes': 2 instead of 3."
1213+
)
1214+
with pytest.warns(RuntimeWarning, match=warning_message):
1215+
train_sizes, _, _ = learning_curve(
1216+
estimator, X, y, cv=3, train_sizes=np.linspace(0.33, 1.0, 3))
12161217
assert_array_equal(train_sizes, [1, 2])
12171218

12181219

@@ -1753,8 +1754,13 @@ def test_fit_and_score_failing():
17531754
# passing error score to trigger the warning message
17541755
fit_and_score_kwargs = {'error_score': 0}
17551756
# check if the warning message type is as expected
1756-
assert_warns(FitFailedWarning, _fit_and_score, *fit_and_score_args,
1757-
**fit_and_score_kwargs)
1757+
warning_message = (
1758+
"Estimator fit failed. The score on this train-test partition for "
1759+
"these parameters will be set to %f."
1760+
% (fit_and_score_kwargs['error_score'])
1761+
)
1762+
with pytest.warns(FitFailedWarning, match=warning_message):
1763+
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
17581764
# since we're using FailingClassfier, our error will be the following
17591765
error_message = "ValueError: Failing classifier failed as required"
17601766
# the warning message we're expecting to see
@@ -1769,8 +1775,13 @@ def test_warn_trace(msg):
17691775
mtb = split[0] + '\n' + split[-1]
17701776
return warning_message in mtb
17711777
# check traceback is included
1772-
assert_warns_message(FitFailedWarning, test_warn_trace, _fit_and_score,
1773-
*fit_and_score_args, **fit_and_score_kwargs)
1778+
warning_message = (
1779+
"Estimator fit failed. The score on this train-test partition for "
1780+
"these parameters will be set to %f."
1781+
% (fit_and_score_kwargs['error_score'])
1782+
)
1783+
with pytest.warns(FitFailedWarning, match=warning_message):
1784+
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
17741785

17751786
fit_and_score_kwargs = {'error_score': 'raise'}
17761787
# check if exception was raised, with default error_score='raise'

0 commit comments

Comments
 (0)
0