Made the following changes · scikit-learn/scikit-learn@a5fe51f · GitHub
[go: up one dir, main page]

Skip to content

Commit a5fe51f

Browse files
committed
Made the following changes
a] Raise ValueError for invalid precompute b] Remove precompute for MultiTask ENet/LassoCV
1 parent 9a31569 commit a5fe51f

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

sklearn/linear_model/coordinate_descent.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,12 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
474474
model = cd_fast.enet_coordinate_descent_gram(
475475
coef_, l1_reg, l2_reg, precompute, Xy, y, max_iter,
476476
tol, positive)
477-
else:
477+
elif precompute is False:
478478
model = cd_fast.enet_coordinate_descent(
479479
coef_, l1_reg, l2_reg, X, y, max_iter, tol, positive)
480+
else:
481+
raise ValueError("Precompute should be one of True, False, "
482+
"'auto' or array-like")
480483
coef_, dual_gap_, eps_ = model
481484
coefs[..., i] = coef_
482485
dual_gaps[i] = dual_gap_
@@ -893,7 +896,13 @@ def _path_residuals(X, y, train, test, path, path_params, alphas=None,
893896
y_test = y[test]
894897
fit_intercept = path_params['fit_intercept']
895898
normalize = path_params['normalize']
896-
precompute = path_params['precompute']
899+
900+
if y.ndim == 1:
901+
precompute = path_params['precompute']
902+
else:
903+
# No Gram variant of multi-task exists right now.
904+
# Fall back to default enet_multitask
905+
precompute = False
897906

898907
X_train, y_train, X_mean, y_mean, X_std, precompute, Xy = \
899908
_pre_fit(X_train, y_train, None, precompute, normalize, fit_intercept,
@@ -1642,11 +1651,6 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
16421651
List of alphas where to compute the models.
16431652
If not provided, set automatically.
16441653
1645-
precompute : True | False | 'auto' | array-like
1646-
Whether to use a precomputed Gram matrix to speed up
1647-
calculations. If set to ``'auto'`` let us decide. The Gram
1648-
matrix can also be passed as argument.
1649-
16501654
n_alphas : int, optional
16511655
Number of alphas along the regularization path
16521656
@@ -1720,8 +1724,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
17201724
... #doctest: +NORMALIZE_WHITESPACE
17211725
MultiTaskElasticNetCV(alphas=None, copy_X=True, cv=None, eps=0.001,
17221726
fit_intercept=True, l1_ratio=0.5, max_iter=1000, n_alphas=100,
1723-
n_jobs=1, normalize=False, precompute='auto', tol=0.0001,
1724-
verbose=0)
1727+
n_jobs=1, normalize=False, tol=0.0001, verbose=0)
17251728
>>> print(clf.coef_)
17261729
[[ 0.52875032 0.46958558]
17271730
[ 0.52875032 0.46958558]]
@@ -1744,7 +1747,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
17441747
path = staticmethod(enet_path)
17451748

17461749
def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
1747-
fit_intercept=True, normalize=False, precompute='auto',
1750+
fit_intercept=True, normalize=False,
17481751
max_iter=1000, tol=1e-4, cv=None, copy_X=True,
17491752
verbose=0, n_jobs=1):
17501753
self.l1_ratio = l1_ratio
@@ -1753,7 +1756,6 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
17531756
self.alphas = alphas
17541757
self.fit_intercept = fit_intercept
17551758
self.normalize = normalize
1756-
self.precompute = precompute
17571759
self.max_iter = max_iter
17581760
self.tol = tol
17591761
self.cv = cv
@@ -1785,11 +1787,6 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
17851787
List of alphas where to compute the models.
17861788
If not provided, set automaticlly.
17871789
1788-
precompute : True | False | 'auto' | array-like
1789-
Whether to use a precomputed Gram matrix to speed up
1790-
calculations. If set to ``'auto'`` let us decide. The Gram
1791-
matrix can also be passed as argument.
1792-
17931790
n_alphas : int, optional
17941791
Number of alphas along the regularization path
17951792
@@ -1860,10 +1857,10 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
18601857
path = staticmethod(lasso_path)
18611858

18621859
def __init__(self, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=True,
1863-
normalize=False, precompute='auto', max_iter=1000, tol=1e-4,
1864-
copy_X=True, cv=None, verbose=False, n_jobs=1):
1860+
normalize=False, max_iter=1000, tol=1e-4, copy_X=True,
1861+
cv=None, verbose=False, n_jobs=1):
18651862
super(MultiTaskLassoCV, self).__init__(
18661863
eps=eps, n_alphas=n_alphas, alphas=alphas,
18671864
fit_intercept=fit_intercept, normalize=normalize,
1868-
precompute=precompute, max_iter=max_iter, tol=tol, copy_X=copy_X,
1865+
max_iter=max_iter, tol=tol, copy_X=copy_X,
18691866
cv=cv, verbose=verbose, n_jobs=n_jobs)

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,13 @@ def test_sparse_input_dtype_enet_and_lassocv():
460460
assert_almost_equal(clf.coef_, clf1.coef_, decimal=6)
461461

462462

463+
def test_precompute_invalid_argument():
464+
X, y, _, _ = build_dataset()
465+
for clf in [ElasticNetCV(precompute="invalid"),
466+
LassoCV(precompute="invalid")]:
467+
assert_raises(ValueError, clf.fit, X, y)
468+
469+
463470
if __name__ == '__main__':
464471
import nose
465472
nose.runmodule()

0 commit comments

Comments (0)