From ae7a4ad0a936870b24efeb6075f5a29dd103e50e Mon Sep 17 00:00:00 2001 From: Manoj-Kumar-S Date: Wed, 4 Jun 2014 00:47:18 +0530 Subject: [PATCH] Remove unused param precompute from MultiTask models --- sklearn/linear_model/coordinate_descent.py | 30 +++++++++------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 89fbc093e4edb..c4b378f74e633 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -889,7 +889,13 @@ def _path_residuals(X, y, train, test, path, path_params, alphas=None, y_test = y[test] fit_intercept = path_params['fit_intercept'] normalize = path_params['normalize'] - precompute = path_params['precompute'] + + if y.ndim == 1: + precompute = path_params['precompute'] + else: + # No Gram variant of multi-task exists right now. + # Fall back to default enet_multitask + precompute = False X_train, y_train, X_mean, y_mean, X_std, precompute, Xy = \ _pre_fit(X_train, y_train, None, precompute, normalize, fit_intercept, @@ -1638,11 +1644,6 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin): List of alphas where to compute the models. If not provided, set automatically. - precompute : True | False | 'auto' | array-like - Whether to use a precomputed Gram matrix to speed up - calculations. If set to ``'auto'`` let us decide. The Gram - matrix can also be passed as argument. - n_alphas : int, optional Number of alphas along the regularization path @@ -1716,8 +1717,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin): ... #doctest: +NORMALIZE_WHITESPACE MultiTaskElasticNetCV(alphas=None, copy_X=True, cv=None, eps=0.001, fit_intercept=True, l1_ratio=0.5, max_iter=1000, n_alphas=100, - n_jobs=1, normalize=False, precompute='auto', tol=0.0001, - verbose=0) + n_jobs=1, normalize=False, tol=0.0001, verbose=0) >>> print(clf.coef_) [[ 0.52875032 0.46958558] [ 0.52875032 0.46958558]] @@ -1740,7 +1740,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin): path = staticmethod(enet_path) def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, - fit_intercept=True, normalize=False, precompute='auto', + fit_intercept=True, normalize=False, max_iter=1000, tol=1e-4, cv=None, copy_X=True, verbose=0, n_jobs=1): self.l1_ratio = l1_ratio @@ -1749,7 +1749,6 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, self.alphas = alphas self.fit_intercept = fit_intercept self.normalize = normalize - self.precompute = precompute self.max_iter = max_iter self.tol = tol self.cv = cv @@ -1781,11 +1780,6 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin): List of alphas where to compute the models. If not provided, set automaticlly. - precompute : True | False | 'auto' | array-like - Whether to use a precomputed Gram matrix to speed up - calculations. If set to ``'auto'`` let us decide. The Gram - matrix can also be passed as argument. - n_alphas : int, optional Number of alphas along the regularization path @@ -1856,10 +1850,10 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin): path = staticmethod(lasso_path) def __init__(self, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=True, - normalize=False, precompute='auto', max_iter=1000, tol=1e-4, - copy_X=True, cv=None, verbose=False, n_jobs=1): + normalize=False, max_iter=1000, tol=1e-4, copy_X=True, + cv=None, verbose=False, n_jobs=1): super(MultiTaskLassoCV, self).__init__( eps=eps, n_alphas=n_alphas, alphas=alphas, fit_intercept=fit_intercept, normalize=normalize, - precompute=precompute, max_iter=max_iter, tol=tol, copy_X=copy_X, + max_iter=max_iter, tol=tol, copy_X=copy_X, cv=cv, verbose=verbose, n_jobs=n_jobs)