From 525db2bac1090627588e67f8221bbd67f9053272 Mon Sep 17 00:00:00 2001 From: joaak <29533036+joaak@users.noreply.github.com> Date: Thu, 11 Oct 2018 23:41:02 -0400 Subject: [PATCH 1/3] update to fix .coef_ issue for MultiTaskLasso --- sklearn/linear_model/coordinate_descent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 86d621b415b3a..1c65a650f8bc2 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -1795,7 +1795,7 @@ def fit(self, X, y): X, y, X_offset, y_offset, X_scale = _preprocess_data( X, y, self.fit_intercept, self.normalize, copy=False) - if not self.warm_start or self.coef_ is None: + if not self.warm_start or not hasattr(self, "coef_"): self.coef_ = np.zeros((n_tasks, n_features), dtype=X.dtype.type, order='F') From 94ad80edd295db5f3b62bdcd310d6c99520823e3 Mon Sep 17 00:00:00 2001 From: joaak <29533036+joaak@users.noreply.github.com> Date: Sun, 14 Oct 2018 19:49:56 -0400 Subject: [PATCH 2/3] add non-regression test for MultiTaskLasso with warm_start = True --- .../linear_model/tests/test_coordinate_descent.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index 834d685f5b23d..1001300cf643f 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -818,3 +818,15 @@ def test_coef_shape_not_zero(): est_no_intercept = Lasso(fit_intercept=False) est_no_intercept.fit(np.c_[np.ones(3)], np.ones(3)) assert est_no_intercept.coef_.shape == (1,) + + +def test_warm_start_multitask_lasso(): + X, y, X_test, y_test = build_dataset() + Y = np.c_[y, y] + clf = MultiTaskLasso(alpha=0.1, max_iter=5, warm_start=True) + ignore_warnings(clf.fit)(X, Y) + ignore_warnings(clf.fit)(X, Y) # do a second round with 5 iterations + + clf2 = MultiTaskLasso(alpha=0.1, max_iter=10) + ignore_warnings(clf2.fit)(X, Y) + assert_array_almost_equal(clf2.coef_, clf.coef_) From 16dc49dfd71b3c3803fd923aa59682c381da5e85 Mon Sep 17 00:00:00 2001 From: joaak <29533036+joaak@users.noreply.github.com> Date: Mon, 15 Oct 2018 17:28:09 -0400 Subject: [PATCH 3/3] add fix for MultiTaskLasso with warm_start = True --- doc/whats_new/v0.21.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 87352b6f8e1f0..0c8a56e9dcaf0 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -40,6 +40,12 @@ Support for Python 3.4 and below has been officially dropped. - An entry goes here - An entry goes here +:mod:`sklearn.linear_model` +........................... +- |Fix| Fixed a bug in :class:`linear_model.MultiTaskElasticNet` which was breaking + ``MultiTaskElasticNet`` and ``MultiTaskLasso`` when ``warm_start = True``. :issue:`12360` + by :user:`Aakanksha Joshi `. + :mod:`sklearn.cluster` ......................