8000 MNT Add validation for parameters in `ElasticNet` (#22240) · scikit-learn/scikit-learn@d7fc1df · GitHub
[go: up one dir, main page]

Skip to content

Commit d7fc1df

Browse files
ArturoAmorQogrisel
andauthored
MNT Add validation for parameters in ElasticNet (#22240)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 8cf0959 commit d7fc1df

File tree

3 files changed

+57
-14
lines changed

3 files changed

+57
-14
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ Changelog
365365
warning when `l1_ratio=0`.
366366
:pr:`21724` by :user:`Yar Khine Phyo <yarkhinephyo>`.
367367

368+
- |Enhancement| :class:`linear_model.ElasticNet` and :class:`linear_model.Lasso`
369+
now raise consistent error messages when passed invalid values for `l1_ratio`,
370+
`alpha`, `max_iter` and `tol`.
371+
:pr:`22240` by :user:`Arturo Amor <ArturoAmorQ>`.
372+
368373
:mod:`sklearn.metrics`
369374
......................
370375

sklearn/linear_model/_coordinate_descent.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..base import RegressorMixin, MultiOutputMixin
1919
from ._base import _preprocess_data, _deprecate_normalize
2020
from ..utils import check_array
21+
from ..utils import check_scalar
2122
from ..utils.validation import check_random_state
2223
from ..model_selection import check_cv
2324
from ..utils.extmath import safe_sparse_dot
@@ -903,6 +904,13 @@ def fit(self, X, y, sample_weight=None, check_input=True):
903904
self.normalize, default=False, estimator_name=self.__class__.__name__
904905
)
905906

907+
check_scalar(
908+
self.alpha,
909+
"alpha",
910+
target_type=numbers.Real,
911+
min_val=0.0,
912+
)
913+
906914
if self.alpha == 0:
907915
warnings.warn(
908916
"With alpha=0, this algorithm does not converge "
@@ -917,15 +925,21 @@ def fit(self, X, y, sample_weight=None, check_input=True):
917925
% self.precompute
918926
)
919927

920-
if (
921-
not isinstance(self.l1_ratio, numbers.Number)
922-
or self.l1_ratio < 0
923-
or self.l1_ratio > 1
924-
):
925-
raise ValueError(
926-
f"l1_ratio must be between 0 and 1; got l1_ratio={self.l1_ratio}"
928+
check_scalar(
929+
self.l1_ratio,
930+
"l1_ratio",
931+
target_type=numbers.Real,
932+
min_val=0.0,
933+
max_val=1.0,
934+
)
935+
936+
if self.max_iter is not None:
937+
check_scalar(
938+
self.max_iter, "max_iter", target_type=numbers.Integral, min_val=1
927939
)
928940

941+
check_scalar(self.tol, "tol", target_type=numbers.Real, min_val=0.0)
942+
929943
# Remember if X is copied
930944
X_copied = False
931945
# We expect X and y to be float64 or float32 Fortran ordered arrays

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,41 @@ def test_assure_warning_when_normalize(CoordinateDescentModel, normalize, n_warn
106106
assert len(record) == n_warnings
107107

108108

109-
@pytest.mark.parametrize("l1_ratio", (-1, 2, None, 10, "something_wrong"))
110-
def test_l1_ratio_param_invalid(l1_ratio):
109+
@pytest.mark.parametrize(
110+
"params, err_type, err_msg",
111+
[
112+
({"alpha": -1}, ValueError, "alpha == -1, must be >= 0.0"),
113+
({"l1_ratio": -1}, ValueError, "l1_ratio == -1, must be >= 0.0"),
114+
({"l1_ratio": 2}, ValueError, "l1_ratio == 2, must be <= 1.0"),
115+
(
116+
{"l1_ratio": "1"},
117+
TypeError,
118+
"l1_ratio must be an instance of <class 'numbers.Real'>, not <class 'str'>",
119+
),
120+
({"tol": -1.0}, ValueError, "tol == -1.0, must be >= 0."),
121+
(
122+
{"tol": "1"},
123+
TypeError,
124+
"tol must be an instance of <class 'numbers.Real'>, not <class 'str'>",
125+
),
126+
({"max_iter": 0}, ValueError, "max_iter == 0, must be >= 1."),
127+
(
128+
{ 6D40 "max_iter": "1"},
129+
TypeError,
130+
"max_iter must be an instance of <class 'numbers.Integral'>, not <class"
131+
" 'str'>",
132+
),
133+
],
134+
)
135+
def test_param_invalid(params, err_type, err_msg):
111136
# Check that correct error is raised when l1_ratio in ElasticNet
112137
# is outside the correct range
113138
X = np.array([[-1.0], [0.0], [1.0]])
114-
Y = [-1, 0, 1] # just a straight line
139+
y = [-1, 0, 1] # just a straight line
115140

116-
msg = "l1_ratio must be between 0 and 1; got l1_ratio="
117-
clf = ElasticNet(alpha=0.1, l1_ratio=l1_ratio)
118-
with pytest.raises(ValueError, match=msg):
119-
clf.fit(X, Y)
141+
enet = ElasticNet(**params)
142+
with pytest.raises(err_type, match=err_msg):
143+
enet.fit(X, y)
120144

121145

122146
@pytest.mark.parametrize("order", ["C", "F"])

0 commit comments

Comments
 (0)
0