MNT Add validation for parameter `alphas` in `RidgeCV` (#21606) · scikit-learn/scikit-learn@3e0f49b · GitHub
[go: up one dir, main page]

Skip to content

Commit 3e0f49b

Browse files
ArturoAmorQglemaitrejjerphan
authored
MNT Add validation for parameter alphas in RidgeCV (#21606)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 678dbe7 commit 3e0f49b

File tree

3 files changed

+73
-19
lines changed

3 files changed

+73
-19
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ Changelog
192192
multilabel classification.
193193
:pr:`19689` by :user:`Guillaume Lemaitre <glemaitre>`.
194194

195+
- |Enhancement| :class:`linear_model.RidgeCV` and
196+
:class:`linear_model.RidgeClassifierCV` now raise consistent error message
197+
when passed invalid values for `alphas`.
198+
:pr:`21606` by :user:`Arturo Amor <ArturoAmorQ>`.
199+
195200
- |Enhancement| :class:`linear_model.Ridge` and :class:`linear_model.RidgeClassifier`
196201
now raise consistent error message when passed invalid values for `alpha`,
197202
`max_iter` and `tol`.

sklearn/linear_model/_ridge.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
from abc import ABCMeta, abstractmethod
13+
from functools import partial
1314
import warnings
1415

1516
import numpy as np
@@ -1864,12 +1865,6 @@ def fit(self, X, y, sample_weight=None):
18641865

18651866
self.alphas = np.asarray(self.alphas)
18661867

1867-
if np.any(self.alphas <= 0):
1868-
raise ValueError(
1869-
"alphas must be strictly positive. Got {} containing some "
1870-
"negative or null value instead.".format(self.alphas)
1871-
)
1872-
18731868
X, y, X_offset, y_offset, X_scale = LinearModel._preprocess_data(
18741869
X,
18751870
y,
@@ -2038,9 +2033,30 @@ def fit(self, X, y, sample_weight=None):
20382033
the validation score.
20392034
"""
20402035
cv = self.cv
2036+
2037+
check_scalar_alpha = partial(
2038+
check_scalar,
2039+
target_type=numbers.Real,
2040+
min_val=0.0,
2041+
include_boundaries="neither",
2042+
)
2043+
2044+
if isinstance(self.alphas, (np.ndarray, list, tuple)):
2045+
n_alphas = 1 if np.ndim(self.alphas) == 0 else len(self.alphas)
2046+
if n_alphas != 1:
2047+
for index, alpha in enumerate(self.alphas):
2048+
alpha = check_scalar_alpha(alpha, f"alphas[{index}]")
2049+
else:
2050+
self.alphas[0] = check_scalar_alpha(self.alphas[0], "alphas")
2051+
else:
2052+
# check for single non-iterable item
2053+
self.alphas = check_scalar_alpha(self.alphas, "alphas")
2054+
2055+
alphas = np.asarray(self.alphas)
2056+
20412057
if cv is None:
20422058
estimator = _RidgeGCV(
2043-
self.alphas,
2059+
alphas,
20442060
fit_intercept=self.fit_intercept,
20452061
normalize=self.normalize,
20462062
scoring=self.scoring,
@@ -2059,7 +2075,8 @@ def fit(self, X, y, sample_weight=None):
20592075
raise ValueError("cv!=None and store_cv_values=True are incompatible")
20602076
if self.alpha_per_target:
20612077
raise ValueError("cv!=None and alpha_per_target=True are incompatible")
2062-
parameters = {"alpha": self.alphas}
2078+
2079+
parameters = {"alpha": alphas}
20632080
solver = "sparse_cg" if sparse.issparse(X) else "auto"
20642081
model = RidgeClassifier if is_classifier(self) else Ridge
20652082
gs = GridSearchCV(

sklearn/linear_model/tests/test_ridge.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,19 +1270,51 @@ def test_ridgecv_int_alphas():
12701270
ridge.fit(X, y)
12711271

12721272

1273-
def test_ridgecv_negative_alphas():
1274-
X = np.array([[-1.0, -1.0], [-1.0, 0], [-0.8, -1.0], [1.0, 1.0], [1.0, 0.0]])
1275-
y = [1, 1, 1, -1, -1]
1273+
@pytest.mark.parametrize("Estimator", [RidgeCV, RidgeClassifierCV])
1274+
@pytest.mark.parametrize(
1275+
"params, err_type, err_msg",
1276+
[
1277+
({"alphas": (1, -1, -100)}, ValueError, r"alphas\[1\] == -1, must be > 0.0"),
1278+
(
1279+
{"alphas": (-0.1, -1.0, -10.0)},
1280+
ValueError,
1281+
r"alphas\[0\] == -0.1, must be > 0.0",
1282+
),
1283+
(
1284+
{"alphas": (1, 1.0, "1")},
1285+
TypeError,
1286+
r"alphas\[2\] must be an instance of <class 'numbers.Real'>, not <class"
1287+
r" 'str'>",
1288+
),
1289+
],
1290+
)
1291+
def test_ridgecv_alphas_validation(Estimator, params, err_type, err_msg):
1292+
"""Check the `alphas` validation in RidgeCV and RidgeClassifierCV."""
12761293

1277-
# Negative integers
1278-
ridge = RidgeCV(alphas=(-1, -10, -100))
1279-
with pytest.raises(ValueError, match="alphas must be strictly positive"):
1280-
ridge.fit(X, y)
1294+
n_samples, n_features = 5, 5
1295+
X = rng.randn(n_samples, n_features)
1296+
y = rng.randint(0, 2, n_samples)
12811297

1282-
# Negative floats
1283-
ridge = RidgeCV(alphas=(-0.1, -1.0, -10.0))
1284-
with pytest.raises(ValueError, match="alphas must be strictly positive"):
1285-
ridge.fit(X, y)
1298+
with pytest.raises(err_type, match=err_msg):
1299+
Estimator(**params).fit(X, y)
1300+
1301+
1302+
@pytest.mark.parametrize("Estimator", [RidgeCV, RidgeClassifierCV])
1303+
def test_ridgecv_alphas_scalar(Estimator):
1304+
"""Check the case when `alphas` is a scalar.
1305+
This case was supported in the past when `alphas` was converted
1306+
into an array in `__init__`.
1307+
We add this test to ensure backward compatibility.
1308+
"""
1309+
1310+
n_samples, n_features = 5, 5
1311+
X = rng.randn(n_samples, n_features)
1312+
if Estimator is RidgeCV:
1313+
y = rng.randn(n_samples)
1314+
else:
1315+
y = rng.randint(0, 2, n_samples)
1316+
1317+
Estimator(alphas=1).fit(X, y)
12861318

12871319

12881320
def test_raises_value_error_if_solver_not_supported():

0 commit comments

Comments
 (0)
0