FIX : fixing Lars lasso with early stopping using alpha_min + adding t… · seckcoder/scikit-learn@0bf053f · GitHub
[go: up one dir, main page]

Skip to content

Commit 0bf053f

Browse files
committed
FIX : fixing Lars lasso with early stopping using alpha_min + adding test for it
1 parent 8876d34 commit 0bf053f

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

scikits/learn/linear_model/least_angle.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,17 @@ def lars_path(X, y, Xy=None, Gram=None, max_features=None,
120120

121121
alphas[n_iter] = C / n_samples
122122

123-
if (C < alpha_min) or (n_active == max_features):
123+
# Check for early stopping
124+
if alphas[n_iter] < alpha_min: # interpolate
125+
# interpolation factor 0 <= ss < 1
126+
ss = (alphas[n_iter-1] - alpha_min) / (alphas[n_iter-1] -
127+
alphas[n_iter])
128+
coefs[n_iter] = coefs[n_iter-1] + ss*(coefs[n_iter] -
129+
coefs[n_iter-1])
130+
alphas[n_iter] = alpha_min
131+
break
132+
133+
if n_active == max_features:
124134
break
125135

126136
if not drop:
@@ -270,13 +280,6 @@ def lars_path(X, y, Xy=None, Gram=None, max_features=None,
270280
if verbose:
271281
print "%s\t\t%s\t\t%s\t\t%s\t\t%s" % (n_iter, '', drop_idx,
272282
n_active, abs(temp))
273-
if alphas[n_iter] < alpha_min: # interpolate
274-
# interpolation factor 0 <= ss < 1
275-
ss = (alphas[n_iter-1] - alpha_min) / (alphas[n_iter-1] -
276-
alphas[n_iter])
277-
coefs[n_iter] = coefs[n_iter-1] + ss*(coefs[n_iter] - coefs[n_iter-1])
278-
alphas[n_iter] = alpha_min
279-
280283

281284
# resize coefs in case of early stop
282285
alphas = alphas[:n_iter+1]

scikits/learn/linear_model/tests/test_least_angle.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,22 @@ def test_lasso_lars_vs_lasso_cd(verbose=False):
111111
error = np.linalg.norm(c - lasso_cd.coef_)
112112
assert error < 0.01
113113

114+
def test_lasso_lars_vs_lasso_cd_early_stopping(verbose=False):
    """
    Test that LassoLars and Lasso using coordinate descent give the
    same results when early stopping is used.
    (test : before, in the middle, and in the last part of the path)
    """
    alphas_min = [10, 0.9, 1e-4]
    # Use a loop variable distinct from the list and forward it to
    # lars_path, so that every listed threshold is actually exercised
    # (instead of re-running the same fixed alpha_min on each iteration).
    for alpha_min in alphas_min:
        alphas, _, lasso_path = linear_model.lars_path(X, y, method='lasso',
                                                       alpha_min=alpha_min)
        lasso_cd = linear_model.Lasso(fit_intercept=False)
        # Fit coordinate descent at the alpha where the LARS path stopped
        lasso_cd.alpha = alphas[-1]
        lasso_cd.fit(X, y, tol=1e-8)
        # Compare the last point of the LARS path with the CD solution
        error = np.linalg.norm(lasso_path[:, -1] - lasso_cd.coef_)
        assert error < 0.01
129+
114130

115131
if __name__ == '__main__':
116132
import nose

0 commit comments

Comments
 (0)
0