10000 BUG: ElasticNetCV choosing improper l1_ratio · nullnotfound/scikit-learn@972e7cf · GitHub
[go: up one dir, main page]

Skip to content

Commit 972e7cf

Browse files
committed
BUG: ElasticNetCV choosing improper l1_ratio
The code was lacking good tests: it had only smoke tests. Shame on me (I am the author).
1 parent c87d45d commit 972e7cf

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

sklearn/linear_model/coordinate_descent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,8 @@ def fit(self, X, y):
758758
Target values
759759
760760
"""
761+
# We avoid copying X so as to save memory. X will be copied
762+
# after the cross-validation loop
761763
X = atleast2d_or_csc(X, dtype=np.float64, order='F',
762764
copy=self.copy_X and self.fit_intercept)
763765
# From now on X can be touched inplace
@@ -776,6 +778,8 @@ def fit(self, X, y):
776778
l1_ratios = [1, ]
777779
path_params.pop('cv', None)
778780
path_params.pop('n_jobs', None)
781+
# We can modify X inplace
782+
path_params['copy_X'] = False
779783

780784
# Start to compute path on full data
781785
# XXX: is this really useful: we are fitting models that we won't
@@ -787,6 +791,11 @@ def fit(self, X, y):
787791
n_alphas = len(alphas)
788792
path_params.update({'alphas': alphas, 'n_alphas': n_alphas})
789793

794+
# If we are not computing in parallel, we don't want to modify X
795+
# inplace in the folds
796+
if self.n_jobs == 1 or self.n_jobs is None:
797+
path_params['copy_X'] = True
798+
790799
# init cross-validation generator
791800
cv = check_cv(self.cv, X)
792801

@@ -814,6 +823,7 @@ def fit(self, X, y):
814823
if this_best_mse < best_mse:
815824
model = models[i_best_alpha]
816825
best_l1_ratio = l1_ratio
826+
best_mse = this_best_mse
817827

818828
if hasattr(model, 'l1_ratio'):
819829
if model.l1_ratio != best_l1_ratio:

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,26 +177,40 @@ def test_lasso_path():
177177

178178

179179
def test_enet_path():
180-
X, y, X_test, y_test = build_dataset()
180+
# We use a large number of samples and of informative features so that
181+
# the l1_ratio selected is more toward ridge than lasso
182+
X, y, X_test, y_test = build_dataset(n_samples=200,
183+
n_features=100,
184+
n_informative_features=100)
181185
max_iter = 150
182186

183187
with warnings.catch_warnings():
184188
# Here we have a small number of iterations, and thus the
185189
# ElasticNet might not converge. This is to speed up tests
186190
warnings.simplefilter("ignore", UserWarning)
187-
clf = ElasticNetCV(n_alphas=5, eps=2e-3, l1_ratio=[0.9, 0.95], cv=3,
191+
clf = ElasticNetCV(n_alphas=5, eps=2e-3, l1_ratio=[0.5, 0.7], cv=3,
188192
max_iter=max_iter)
189193
clf.fit(X, y)
190-
assert_almost_equal(clf.alpha_, 0.002, 2)
191-
assert_equal(clf.l1_ratio_, 0.95)
192-
193-
clf = ElasticNetCV(n_alphas=5, eps=2e-3, l1_ratio=[0.9, 0.95], cv=3,
194+
# Well-conditioned settings, we should have selected our
195+
# smallest penalty
196+
assert_almost_equal(clf.alpha_, min(clf.alphas_))
197+
# Non-sparse ground truth: we should have selected an elastic-net
198+
# that is closer to ridge than to lasso
199+
assert_equal(clf.l1_ratio_, min(clf.l1_ratio))
200+
201+
clf = ElasticNetCV(n_alphas=5, eps=2e-3, l1_ratio=[0.5, 0.7], cv=3,
194202
max_iter=max_iter, precompute=True)
195203
clf.fit(X, y)
196-
assert_almost_equal(clf.alpha_, 0.002, 2)
197-
assert_equal(clf.l1_ratio_, 0.95)
198204

199-
# test set
205+
# Well-conditioned settings, we should have selected our
206+
# smallest penalty
207+
assert_almost_equal(clf.alpha_, min(clf.alphas_))
208+
# Non-sparse ground truth: we should have selected an elastic-net
209+
# that is closer to ridge than to lasso
210+
assert_equal(clf.l1_ratio_, min(clf.l1_ratio))
211+
212+
# We are in well-conditioned settings with low noise: we should
213+
# have a good test-set performance
200214
assert_greater(clf.score(X_test, y_test), 0.99)
201215

202216

0 commit comments

Comments
 (0)
0