File tree 2 files changed +16
-9
lines changed
sklearn/linear_model/tests 2 files changed +16
-9
lines changed Original file line number Diff line number Diff line change @@ -184,14 +184,15 @@ def test_deprecate_normalize(normalize, default):
184
184
expected = None
185
185
warning_msg = []
186
186
187
- with pytest . warns ( expected ) as record :
188
- _normalize = _deprecate_normalize ( normalize , default , "estimator" )
189
- assert _normalize == output
190
-
191
- n_warnings = 0 if expected is None else 1
192
- assert len ( record ) == n_warnings
193
- if n_warnings :
187
+ if expected is None :
188
+ with warnings . catch_warnings ():
189
+ warnings . simplefilter ( "error" , FutureWarning )
190
+ _normalize = _deprecate_normalize ( normalize , default , "estimator" )
191
+ else :
192
+ with pytest . warns ( expected ) as record :
193
+ _normalize = _deprecate_normalize ( normalize , default , "estimator" )
194
194
assert all ([warning in str (record [0 ].message ) for warning in warning_msg ])
195
+ assert _normalize == output
195
196
196
197
197
198
def test_linear_regression_sparse (random_state = 0 ):
Original file line number Diff line number Diff line change 5
5
import pytest
6
6
7
7
import sys
8
+ import warnings
8
9
import numpy as np
9
10
10
11
from sklearn .base import is_classifier
@@ -56,6 +57,12 @@ def test_linear_model_normalize_deprecation_message(
56
57
y = np .sign (y )
57
58
58
59
model = estimator (normalize = normalize )
60
+ if warning_category is None :
61
+ with warnings .catch_warnings ():
62
+ warnings .simplefilter ("error" , FutureWarning )
63
+ model .fit (X , y )
64
+ return
65
+
59
66
with pytest .warns (warning_category ) as record :
60
67
model .fit (X , y )
61
68
# Filter record in case other unrelated warnings are raised
@@ -67,6 +74,5 @@ def test_linear_model_normalize_deprecation_message(
67
74
msg += "\n "
68
75
raise AssertionError (msg )
69
76
wanted = [r for r in record if r .category == warning_category ]
70
- if warning_category is not None :
71
- assert "'normalize' was deprecated" in str (wanted [0 ].message )
77
+ assert "'normalize' was deprecated" in str (wanted [0 ].message )
72
78
assert len (wanted ) == n_warnings
You can’t perform that action at this time.
0 commit comments