-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
SLEP006: ClassifierChain and RegressorChain routing #24443
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SLEP006: ClassifierChain and RegressorChain routing #24443
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@@ -646,7 +646,7 @@ def fit(self, X, y, **fit_params): | |||
self.sample_weight_ = fit_params["sample_weight"] | |||
super().fit(X, y, **fit_params) | |||
|
|||
model = RegressorChain(MySGD()) | |||
model = RegressorChain(MySGD().set_fit_request(sample_weight=True)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need a test that checks that the existing behaviour still works but issues a DeprecationWarning
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the common tests are supposed to do that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR!
@@ -655,6 +656,15 @@ def fit(self, X, y, **fit_params): | |||
for est in model.estimators_: | |||
assert est.sample_weight_ is weight | |||
|
|||
# Test that the existing behavior works and raises a FutureWarning |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make it easier to search for when we need to remove the warning:
# Test that the existing behavior works and raises a FutureWarning | |
# TODO(1.4): Remove check for FutureWarning | |
# Test that the existing behavior works and raises a FutureWarning |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @OmarManzoor
@@ -646,7 +646,7 @@ def fit(self, X, y, **fit_params): | |||
self.sample_weight_ = fit_params["sample_weight"] | |||
super().fit(X, y, **fit_params) | |||
|
|||
model = RegressorChain(MySGD()) | |||
model = RegressorChain(MySGD().set_fit_request(sample_weight=True)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the common tests are supposed to do that.
sklearn/multioutput.py
Outdated
estimator=self.base_estimator, | ||
method_mapping=MethodMapping().add(callee="fit", caller="fit"), | ||
) | ||
.warn_on(child="estimator", method="fit", params=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method (get_metadata_routing
) needs to be defined in the child classes because one of them (RegressorChain) already accepts and passes fit_params
(hence the warning here correct), but the ClassifierChain doesn't and passing fit_params
is added in this PR and therefore shouldn't add this warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise LGTM.
Reference Issues/PRs
Towards: #22893
What does this implement/fix? Explain your changes.
Any other comments?
None