8000 Merge pull request #168 from yarikoptic/0.8.X · jwchennlp/scikit-learn@7eae6e0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7eae6e0

Browse files
committed
Merge pull request scikit-learn#168 from yarikoptic/0.8.X
0.8.x: FIX: lars_path -- assure that at least some features get added if necessary
2 parents 67ff4ef + 8ddf771 commit 7eae6e0

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

scikits/learn/linear_model/least_angle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def lars_path(X, y, Xy=None, Gram=None, max_features=None,
222222

223223
if n_iter >= coefs.shape[0]:
224224
# resize the coefs and alphas array
225-
add_features = 2 * (max_features - n_active)
225+
add_features = 2 * max(1, (max_features - n_active))
226226
coefs.resize((n_iter + add_features, n_features))
227227
alphas.resize(n_iter + add_features)
228228

scikits/learn/linear_model/tests/test_least_angle.py

Lines changed: 15 additions & 0 deletions
127
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,21 @@ def test_lasso_lars_vs_lasso_cd_early_stopping(verbose=False):
127
error = np.linalg.norm(lasso_path[:,-1] - lasso_cd.coef_)
128128
assert error < 0.01
129129

130+
def test_lars_add_features(verbose=False):
131+
linear_model.LARS(verbose=verbose, fit_intercept=True).fit(
132+
np.array([[ 0.02863763, 0.88144085, -0.02052429, -0.10648066, -0.06396584, -0.18338974],
133+
[ 0.02038287, 0.51463335, -0.31734681, -0.12830467, 0.16870657, 0.02169503],
134+
[ 0.14411476, 0.37666599, 0.2764702 , 0.0723859 , -0.03812009, 0.03663579],
135+
[-0.29411448, 0.33321005, 0.09429278, -0.10635334, 0.02827505, -0.07307312],
136+
[-0.40929514, 0.57692643, -0.12559217, 0.19001991, 0.07381565, -0.0072319 ],
137+
[-0.01763028, 1. , 0.04437242, 0.11870747, 0.1235008 , -0.27375014],
138+
[-0.06482493, 0.1233536 , 0.15686536, 0.02059646, -0.31723546, 0.42050836],
139+
[-0.18806577, 0.01970053, 0.02258482, -0.03216307, 0.17196751, 0.34123213],
140+
[ 0.11277307, 0.15590351, 0.11231502, 0.22009306, 0.1811108 , 0.51456405],
141+
[ 0.03228484, -0.12317732, -0.34223564, 0.08323492, -0.15770904, 0.39392212],
142+
[-0.00586796, 0.04902901, 0.18020746, 0.04370165, -0.06686751, 0.50099547],
143+
[-0.12951744, 0.21978613, -0.04762174, -0.27227304, -0.02722684, 0.57449581]]),
144+
np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]))
130145

131146
if __name__ == '__main__':
132147
import nose

0 commit comments

Comments
 (0)
0