8000 MNT Postpone conversion of `RidgeCV`'s `alphas` out of `__init__` (#2… · thomasjpfan/scikit-learn@2cf3c55 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2cf3c55

Browse files
authored
MNT Postpone conversion of RidgeCV's alphas out of __init__ (scikit-learn#21506)
1 parent d128e79 commit 2cf3c55

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

sklearn/linear_model/_ridge.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,7 +1483,7 @@ def __init__(
14831483
is_clf=False,
14841484
alpha_per_target=False,
14851485
):
1486-
self.alphas = np.asarray(alphas)
1486+
self.alphas = alphas
14871487
self.fit_intercept = fit_intercept
14881488
self.normalize = normalize
14891489
self.scoring = scoring
@@ -1842,6 +1842,8 @@ def fit(self, X, y, sample_weight=None):
18421842
if sample_weight is not None:
18431843
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
18441844

1845+
self.alphas = np.asarray(self.alphas)
1846+
18451847
if np.any(self.alphas <= 0):
18461848
raise ValueError(
18471849
"alphas must be strictly positive. Got {} containing some "
@@ -1977,7 +1979,7 @@ def __init__(
19771979
store_cv_values=False,
19781980
alpha_per_target=False,
19791981
):
1980-
self.alphas = np.asarray(alphas)
1982+
self.alphas = alphas
19811983
self.fit_intercept = fit_intercept
19821984
self.normalize = normalize
19831985
self.scoring = scoring

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def test_lasso_cv_positive_constraint():
351351

352352
def _scale_alpha_inplace(estimator, n_samples):
353353
"""Rescale the parameter alpha from when the estimator is evoked with
354-
normalize set to True to when it is evoked in a Pipeline with normalize set
354+
normalize set to True as if it were evoked in a Pipeline with normalize set
355355
to False and with a StandardScaler.
356356
"""
357357
if ("alpha" not in estimator.get_params()) and (
@@ -360,7 +360,10 @@ def _scale_alpha_inplace(estimator, n_samples):
360360
return
361361

362362
if isinstance(estimator, (RidgeCV, RidgeClassifierCV)):
363-
alphas = estimator.alphas * n_samples
363+
# alphas is not validated at this point and can be a list.
364+
# We convert it to a np.ndarray to make sure broadcasting
365+
# is used.
366+
alphas = np.asarray(estimator.alphas) * n_samples
364367
return estimator.set_params(alphas=alphas)
365368
if isinstance(estimator, (Lasso, LassoLars, MultiTaskLasso)):
366369
alpha = estimator.alpha * np.sqrt(n_samples)

sklearn/linear_model/tests/test_ridge.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,27 @@ def test_ridge_classifier_cv_store_cv_values(scoring):
11101110
assert r.cv_values_.shape == (n_samples, n_targets, n_alphas)
11111111

11121112

1113+
@pytest.mark.parametrize("Estimator", [RidgeCV, RidgeClassifierCV])
1114+
def test_ridgecv_alphas_conversion(Estimator):
1115+
rng = np.random.RandomState(0)
1116+
alphas = (0.1, 1.0, 10.0)
1117+
1118+
n_samples, n_features = 5, 5
1119+
if Estimator is RidgeCV:
1120+
y = rng.randn(n_samples)
1121+
else:
1122+
y = rng.randint(0, 2, n_samples)
1123+
X = rng.randn(n_samples, n_features)
1124+
1125+
ridge_est = Estimator(alphas=alphas)
1126+
assert (
1127+
ridge_est.alphas is alphas
1128+
), f"`alphas` was mutated in `{Estimator.__name__}.__init__`"
1129+
1130+
ridge_est.fit(X, y)
1131+
assert_array_equal(ridge_est.alphas, np.asarray(alphas))
1132+
1133+
11131134
def test_ridgecv_sample_weight():
11141135
rng = np.random.RandomState(0)
11151136
alphas = (0.1, 1.0, 10.0)

0 commit comments

Comments
 (0)
0