[MRG] Fix GBDT init parameter when it's a pipeline by NicolasHug · Pull Request #13472 · scikit-learn/scikit-learn · GitHub

Merged (8 commits) Mar 27, 2019

Conversation

NicolasHug (Member)

Reference Issues/PRs

Fixes #13466

What does this implement/fix? Explain your changes.

This PR fixes support for the init estimator of GBDTs when init is a pipeline.

Note that pipelines do not support sample weights.
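To illustrate the scenario this PR enables, here is a minimal sketch (the estimators and data are illustrative, not taken from this PR): with the fix, a Pipeline can be passed as init as long as no sample weights are involved.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_regression(random_state=0)

# A Pipeline used as the initial estimator of the GBDT; fitting without
# sample_weight works because the pipeline is then fit without weights.
init = make_pipeline(StandardScaler(), LinearRegression())
gb = GradientBoostingRegressor(init=init, n_estimators=5, random_state=0)
gb.fit(X, y)
pred = gb.predict(X)
```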

Any other comments?

@Thomasillo

Perfect, thanks a lot!

@@ -1484,7 +1484,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
             else:
                 try:
                     self.init_.fit(X, y, sample_weight=sample_weight)
-                except TypeError:
+                except (TypeError, ValueError):
Member:
I feel like the catching here should actually be more strict rather than loose. A lot of (if not most) init parameter validation happens in fit, raising a ValueError when the parameters are invalid; here the user would instead see a message complaining about sample_weight, which would be irrelevant.

Wouldn't checking the signature, or the appropriate estimator tag, be a better idea here?

NicolasHug (Member Author), Mar 19, 2019:

the user would instead see a message complaining about sample_weights which would be irrelevant.

Why is it irrelevant? This is precisely what this check is about.

Wouldn't checking the signature, or the [appropriate] estimator tag be a better idea here?

Yes I agree, but apparently using a try/except is preferred #12983 (comment)

EDIT: just saw that a supports_sample_weight tag is currently being discussed in #13438, but it's far from done

Member:

It's not elegant but I'm okay with this.

Member:

For example, NuSVR supports sample_weight in fit:

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.svm import NuSVR

    X, y = make_regression(random_state=0)
    NuSVR(nu=1.5).fit(X, y, sample_weight=np.ones(X.shape[0]))

gives:

    ValueError: nu <= 0 or nu > 1

But after this PR:

    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.pipeline import make_pipeline

    init = make_pipeline(NuSVR(nu=1.5))
    gb = GradientBoostingRegressor(init=init)
    gb.fit(X, y, sample_weight=np.ones(X.shape[0]))

gives:

    ValueError: The initial estimator Pipeline does not support sample weights.

NicolasHug (Member Author):

Ooh you were talking about the input checking of the init estimator, ok, good point.

@jnothman do you think using has_fit_param() would be justified here? As far as I understand, this would only be a problem if a user passes a custom estimator that accepts sample_weight in fit() as a keyword argument.

Another option would be to test the error message of the ValueError coming from a pipeline and only raise in this case?
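For context, has_fit_param presumably refers to a signature check along the lines of sklearn.utils.validation.has_fit_parameter; a sketch of how such a check behaves differently for a plain estimator and a Pipeline:

```python
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.utils.validation import has_fit_parameter

# A plain estimator exposes sample_weight in its fit signature.
plain_ok = has_fit_parameter(LinearRegression(), "sample_weight")

# Pipeline.fit only accepts **fit_params, so the signature check fails,
# even though weights could be routed to a step as `<step>__sample_weight`.
pipe_ok = has_fit_parameter(make_pipeline(LinearRegression()), "sample_weight")
```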

Member:

I don't see how has_fit_param helps here.

But perhaps we should raise a more equivocal error message ("could not fit init estimator with sample_weight") and use raise ... from to report the original exception.
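A minimal sketch of that raise ... from pattern (the toy estimator and the message are made up for illustration):

```python
class NoWeightEstimator:
    """Toy estimator whose fit does not accept sample_weight."""
    def fit(self, X, y):
        return self

try:
    try:
        NoWeightEstimator().fit([[0.0]], [0.0], sample_weight=[1.0])
    except TypeError as exc:
        # `raise ... from` chains the original TypeError onto the new error,
        # so the traceback reports both messages.
        raise ValueError("could not fit init estimator with sample_weight") from exc
except ValueError as err:
    chained = err.__cause__  # the original TypeError is preserved here
```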

NicolasHug (Member Author):

I don't see how has_fit_param helps here.

This check is only supposed to check whether the init estimator supports sample_weight. I had to add ValueError for pipelines because, unlike regular estimators, they don't raise TypeError. As @adrinjalali noted, the check now also catches ValueErrors coming from other causes (namely input checking).

Using has_fit_param would avoid this, I think.

Member:

Using has_fit_param would avoid this, I think.

How so?

NicolasHug (Member Author):

Using has_fit_param avoids using a try/except.

@@ -1484,7 +1484,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
             else:
                 try:
                     self.init_.fit(X, y, sample_weight=sample_weight)
-                except TypeError:
+                except (TypeError, ValueError):
Member:

It's not elegant but I'm okay with this.

@jnothman (Member) commented Mar 22, 2019 via email

@NicolasHug (Member Author)

I updated the code. I hope it's clearer now.

    if 'not enough values to unpack' in str(e):  # pipeline
        raise ValueError(msg)
    else:  # regular estimator whose input checking failed
        raise e
Member:

Small nit:

except ValueError as e:
    if 'not enough values to unpack' in str(e):  # pipeline
        raise ValueError(msg)
    raise  # regular estimator whose input checking failed

-    else:  # regular estimator whose input checking failed
-        raise e
+    raise
Member:

Nit: the else is not needed here

NicolasHug (Member Author), Mar 22, 2019:

I personally prefer the whole if/else logic. It's clearer, it doesn't rely on the fact that the above block exits, and it has a more functional flavor.

jnothman (Member) left a comment:

A bit messy but okay I guess. Maybe we should just make Pipeline raise a TypeError (or something that's both Type and Value), though.

    except TypeError:  # regular estimator without SW support
        raise ValueError(msg)
    except ValueError as e:
        if 'not enough values to unpack' in str(e):  # pipeline
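Put together, the merged handling can be sketched as a standalone function (fit_init, msg, and the toy estimator below are illustrative, not the actual scikit-learn code):

```python
def fit_init(init, X, y, sample_weight):
    # TypeError -> a regular estimator whose fit() has no sample_weight;
    # the specific unpacking ValueError -> a Pipeline rejecting fit params.
    msg = ("The initial estimator {} does not support sample weights."
           .format(type(init).__name__))
    try:
        init.fit(X, y, sample_weight=sample_weight)
    except TypeError:  # regular estimator without SW support
        raise ValueError(msg)
    except ValueError as e:
        if 'not enough values to unpack' in str(e):  # pipeline
            raise ValueError(msg)
        raise  # regular estimator whose input checking failed

class NoWeights:
    """Toy estimator that does not accept sample_weight."""
    def fit(self, X, y):
        return self

try:
    fit_init(NoWeights(), [[0.0]], [0.0], sample_weight=[1.0])
except ValueError as err:
    message = str(err)
```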
Member:

I'd rather we improved the message for fit params missing __ in Pipeline, but okay

@NicolasHug (Member Author)

I think we'll be able to make it much better once we have a 'supports_sample_weight' tag

@adrinjalali (Member)

I think we'll be able to make it much better once we have a 'supports_sample_weight' tag

Could you then add a TODO or a XXX note so that at some point we do change it to exploit the tag once it's there?

@adrinjalali (Member)

@jnothman does your approval here stand? I'm not sure what you think about this one now.

@jnothman (Member)

Let's merge and then change when #13534 is fixed.

@jnothman jnothman merged commit d6b368e into scikit-learn:master Mar 27, 2019
@jnothman (Member)

Thanks @NicolasHug!

xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019
Development

Successfully merging this pull request may close these issues.

GradientBoostingRegressor initial estimator does not play together with Pipeline
5 participants