8000 [MRG+1] Fix Error: 'MultiTaskLasso' object has no attribute 'coef_' when warm_start = True by joaak · Pull Request #12361 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] Fix Error: 'MultiTaskLasso' object has no attribute 'coef_' when warm_start = True #12361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ Support for Python 3.4 and below has been officially dropped.
- An entry goes here
- An entry goes here

:mod:`sklearn.linear_model`
...........................
- |Fix| Fixed a bug in :class:`linear_model.MultiTaskElasticNet` which was breaking
``MultiTaskElasticNet`` and ``MultiTaskLasso`` when ``warm_start = True``. :issue:`12360`
by :user:`Aakanksha Joshi <joaak>`.

:mod:`sklearn.cluster`
......................

Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ def fit(self, X, y):
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, self.fit_intercept, self.normalize, copy=False)

if not self.warm_start or self.coef_ is None:
if not self.warm_start or not hasattr(self, "coef_"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be safe, we might want to handle the case where coef_ is None in case someone relied on this broken implementation...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this condition

Suggested change
if not self.warm_start or not hasattr(self, "coef_"):
if not self.warm_start or not hasattr(self, "coef_") or self.coef_ is None:

should be able to handle that?

self.coef_ = np.zeros((n_tasks, n_features), dtype=X.dtype.type,
order='F')

Expand Down
12 changes: 12 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,3 +818,15 @@ def test_coef_shape_not_zero():
est_no_intercept = Lasso(fit_intercept=False)
est_no_intercept.fit(np.c_[np.ones(3)], np.ones(3))
assert est_no_intercept.coef_.shape == (1,)


def test_warm_start_multitask_lasso():
X, y, X_test, y_test = build_dataset()
Y = np.c_[y, y]
clf = MultiTaskLasso(alpha=0.1, max_iter=5, warm_start=True)
ignore_warnings(clf.fit)(X, Y)
ignore_warnings(clf.fit)(X, Y) # do a second round with 5 iterations

clf2 = MultiTaskLasso(alpha=0.1, max_iter=10)
ignore_warnings(clf2.fit)(X, Y)
assert_array_almost_equal(clf2.coef_, clf.coef_)
0