8000 Changed default argument of precompute in ElasticNet and Lasso · scikit-learn/scikit-learn@9904a04 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9904a04

Browse files
committed
Changed default argument of precompute in ElasticNet and Lasso
Setting precompute to "auto" was found to be slower when n_samples > n_features since the computation of the Gram matrix is computationally expensive and outweighs the benefit of fitting the Gram for just one alpha.
1 parent 8357f17 commit 9904a04

File tree

4 files changed

+23
-5
lines changed

4 files changed

+23
-5
lines changed

doc/modules/linear_model.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ for another implementation::
184184
>>> clf = linear_model.Lasso(alpha = 0.1)
185185
>>> clf.fit([[0, 0], [1, 1]], [0, 1])
186186
Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,
187-
normalize=False, positive=False, precompute='auto', random_state=None,
187+
normalize=False, positive=False, precompute=False, random_state=None,
188188
selection='cyclic', tol=0.0001, warm_start=False)
189189
>>> clf.predict([[1, 1]])
190190
array([ 0.8])

doc/tutorial/statistical_inference/supervised_learning.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ application of Occam's razor: *prefer simpler models*.
327327
>>> regr.alpha = best_alpha
328328
>>> regr.fit(diabetes_X_train, diabetes_y_train)
329329
Lasso(alpha=0.025118864315095794, copy_X=True, fit_intercept=True,
330-
max_iter=1000, normalize=False, positive=False, precompute='auto',
330+
max_iter=1000, normalize=False, positive=False, precompute=False,
331331
random_state=None, selection='cyclic', tol=0.0001, warm_start=False)
332332
>>> print(regr.coef_)
333333
[ 0. -212.43764548 517.19478111 313.77959962 -160.8303982 -0.

doc/whats_new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ API changes summary
132132
but previous versions accidentally returned only the positive
133133
probability. Fixed by Will Lamond and `Lars Buitinck`_.
134134

135+
- Change default value of precompute in :class:`ElasticNet` and :class:`Lasso`
136+
to False. Setting precompute to "auto" was found to be slower when
137+
n_samples > n_features since the computation of the Gram matrix is
138+
computationally expensive and outweighs the benefit of fitting the Gram
139+
for just one alpha.
140+
``precompute="auto"`` is now deprecated and will be removed in 0.18.
141+
By `Manoj Kumar`_.
135142

136143
.. _changes_0_15_2:
137144

sklearn/linear_model/coordinate_descent.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,8 @@ class ElasticNet(LinearModel, RegressorMixin):
604604
calculations. If set to ``'auto'`` let us decide. The Gram
605605
matrix can also be passed as argument. For sparse input
606606
this option is always ``True`` to preserve sparsity.
607+
WARNING: The ``'auto'`` option is deprecated and will
608+
be removed in 0.18.
607609
608610
max_iter : int, optional
609611
The maximum number of iterations
@@ -665,7 +667,7 @@ class ElasticNet(LinearModel, RegressorMixin):
665667
path = staticmethod(enet_path)
666668

667669
def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
668-
normalize=False, precompute='auto', max_iter=1000,
670+
normalize=False, precompute=False, max_iter=1000,
669671
copy_X=True, tol=1e-4, warm_start=False, positive=False,
670672
random_state=None, selection='cyclic'):
671673
self.alpha = alpha
@@ -708,6 +710,13 @@ def fit(self, X, y):
708710
warnings.warn("With alpha=0, this algorithm does not converge "
709711
"well. You are advised to use the LinearRegression "
710712
"estimator", stacklevel=2)
713+
714+
if self.precompute == 'auto':
715+
warnings.warn("Setting precompute to 'auto' was found to be "
716+
"slower even when n_samples > n_features. Hence "
717+
"it will be removed in 0.18.",
718+
DeprecationWarning, stacklevel=2)
719+
711720
X = check_array(X, 'csc', dtype=np.float64, order='F', copy=self.copy_X
712721
and self.fit_intercept)
713722
# From now on X can be touched inplace
@@ -830,6 +839,8 @@ class Lasso(ElasticNet):
830839
calculations. If set to ``'auto'`` let us decide. The Gram
831840
matrix can also be passed as argument. For sparse input
832841
this option is always ``True`` to preserve sparsity.
842+
WARNING: The ``'auto'`` option is deprecated and will
843+
be removed in 0.18.
833844
834845
max_iter : int, optional
835846
The maximum number of iterations
@@ -880,7 +891,7 @@ class Lasso(ElasticNet):
880891
>>> clf = linear_model.Lasso(alpha=0.1)
881892
>>> clf.fit([[0,0], [1, 1], [2, 2]], [0, 1, 2])
882893
Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,
883-
normalize=False, positive=False, precompute='auto', random_state=None,
894+
normalize=False, positive=False, precompute=False, random_state=None,
884895
selection='cyclic', tol=0.0001, warm_start=False)
885896
>>> print(clf.coef_)
886897
[ 0.85 0. ]
@@ -906,7 +917,7 @@ class Lasso(ElasticNet):
906917
path = staticmethod(enet_path)
907918

908919
def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
909-
precompute='auto', copy_X=True, max_iter=1000,
920+
precompute=False, copy_X=True, max_iter=1000,
910921
tol=1e-4, warm_start=False, positive=False,
911922
random_state=None, selection='cyclic'):
912923
super(Lasso, self).__init__(

0 commit comments

Comments
 (0)
0