8000 [MRG+1] Fix Error: 'MultiTaskLasso' object has no attribute 'coef_' when warm_start = True by joaak · Pull Request #12361 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] Fix Error: 'MultiTaskLasso' object has no attribute 'coef_' when warm_start = True #12361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 21, 2018

Conversation

joaak
Copy link
Contributor
@joaak joaak commented Oct 12, 2018

Reference Issues/PRs

Fixes #12360

What does this implement/fix? Explain your changes.

In the original source code, within the class MultiTaskElaticNet, if not self.warm_start or self.coef_ is None: (line 1798) throws an error that the object has no attribute 'coef_' when a MultiTaskElaticNet/MultiTaskLasso object is created with warm_start = True.

Potential Fix:

Replace

if not self.warm_start or self.coef_ is None:

with

if not self.warm_start or not hasattr(self, "coef_"):

(The replacement is what has been done in line 733 within the ElasticNet class in the original source code).

@jnothman
Copy link
Member

Please add a non-regression test

@joaak
Copy link
Contributor Author
joaak commented Oct 15, 2018

AppVeyor build seems to be failing for Environment: PYTHON=C:\Python27, PYTHON_VERSION=2.7.8, PYTHON_ARCH=32 and for a test called test_count_nonzero.

I'm not sure if this was triggered by the test I added. Please let me know if there's anything required on my side to rectify this.

Copy link
Member
@agramfort agramfort left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this fix should be documented in what's new. thx @joaak

Copy link
Member
@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM

@@ -1795,7 +1795,7 @@ def fit(self, X, y):
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, self.fit_intercept, self.normalize, copy=False)

if not self.warm_start or self.coef_ is None:
if not self.warm_start or not hasattr(self, "coef_"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be safe, we might want to handle the case where coef_ is None in case someone relied on this broken implementation...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this condition

Suggested change
if not self.warm_start or not hasattr(self, "coef_"):
if not self.warm_start or not hasattr(self, "coef_") or self.coef_ is None:

should be able to handle that?

@agramfort agramfort changed the title Fix Error: 'MultiTaskLasso' object has no attribute 'coef_' when warm_start = True [MRG+1] Fix Error: 'MultiTaskLasso' object has no attribute 'coef_' when warm_start = True Oct 16, 2018
@agramfort
Copy link
Member

got to go from my end

@jnothman
Copy link
Member
jnothman commented Oct 18, 2018 via email

@agramfort
Copy link
Member

what's new is updated and test is there.

@jnothman ok to merge?

@jnothman
Copy link
Member

Thanks @joaak

@jnothman jnothman merged commit 3c76b9c into scikit-learn:master Oct 21, 2018
@jnothman jnothman added this to the 0.20.3 milestone Jan 15, 2019
jnothman pushed a commit to jnothman/scikit-learn that referenced this pull request Feb 19, 2019
@jnothman jnothman mentioned this pull request Feb 19, 2019
17 tasks
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error: 'MultiTaskLasso' object has no attribute 'coef_' when warm_start = True
3 participants
0