8000 MNT Replace pytest.warns(None) in linear_models tests (#23138) · scikit-learn/scikit-learn@3d09637 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3d09637

Browse files
authored
MNT Replace pytest.warns(None) in linear_models tests (#23138)
1 parent a739f6c commit 3d09637

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

sklearn/linear_model/tests/test_base.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,15 @@ def test_deprecate_normalize(normalize, default):
184184
expected = None
185185
warning_msg = []
186186

187-
with pytest.warns(expected) as record:
188-
_normalize = _deprecate_normalize(normalize, default, "estimator")
189-
assert _normalize == output
190-
191-
n_warnings = 0 if expected is None else 1
192-
assert len(record) == n_warnings
193-
if n_warnings:
187+
if expected is None:
188+
with warnings.catch_warnings():
189+
warnings.simplefilter("error", FutureWarning)
190+
_normalize = _deprecate_normalize(normalize, default, "estimator")
191+
else:
192+
with pytest.warns(expected) as record:
193+
_normalize = _deprecate_normalize(normalize, default, "estimator")
194194
assert all([warning in str(record[0].message) for warning in warning_msg])
195+
assert _normalize == output
195196

196197

197198
def test_linear_regression_sparse(random_state=0):

sklearn/linear_model/tests/test_common.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
import sys
8+
import warnings
89
import numpy as np
910

1011
from sklearn.base import is_classifier
@@ -56,6 +57,12 @@ def test_linear_model_normalize_deprecation_message(
5657
y = np.sign(y)
5758

5859
model = estimator(normalize=normalize)
60+
if warning_category is None:
61+
with warnings.catch_warnings():
62+
warnings.simplefilter("error", FutureWarning)
63+
model.fit(X, y)
64+
return
65+
5966
with pytest.warns(warning_category) as record:
6067
model.fit(X, y)
6168
# Filter record in case other unrelated warnings are raised
@@ -67,6 +74,5 @@ def test_linear_model_normalize_deprecation_message(
6774
msg += "\n"
6875
raise AssertionError(msg)
6976
wanted = [r for r in record if r.category == warning_category]
70-
if warning_category is not None:
71-
assert "'normalize' was deprecated" in str(wanted[0].message)
77+
assert "'normalize' was deprecated" in str(wanted[0].message)
7278
assert len(wanted) == n_warnings

0 commit comments

Comments
 (0)
0