8000 SLEP006: ClassifierChain and RegressorChain routing by OmarManzoor · Pull Request #24443 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

SLEP006: ClassifierChain and RegressorChain routing #24443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

OmarManzoor
Copy link
Contributor

Reference Issues/PRs

Towards: #22893

What does this implement/fix? Explain your changes.

  • Added meta data routing to ClassifierChain and RegressorChain meta estimator's fit methods
  • Added the main meta data routing code in _BaseChain which is then inherited by the ClassifierChain and RegressorChain.
  • Updated tests to account for the new additions

Any other comments?

None

@OmarManzoor OmarManzoor changed the base branch from main to sample-props September 15, 2022 14:27
@OmarManzoor OmarManzoor changed the title Classifier and regression chain routing SLEP006: Classifier and regression chain routing Sep 15, 2022
@OmarManzoor OmarManzoor changed the title SLEP006: Classifier and regression chain routing SLEP006: ClassifierChain and RegressorChain routing Sep 15, 2022
Copy link
Member
@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@@ -646,7 +646,7 @@ def fit(self, X, y, **fit_params):
self.sample_weight_ = fit_params["sample_weight"]
super().fit(X, y, **fit_params)

model = RegressorChain(MySGD())
model = RegressorChain(MySGD().set_fit_request(sample_weight=True))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a test that checks that the existing behaviour still works but issues a DeprecationWarning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the common tests are supposed to do that.

Copy link
Member
@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR!

@@ -655,6 +656,15 @@ def fit(self, X, y, **fit_params):
for est in model.estimators_:
assert est.sample_weight_ is weight

# Test that the existing behavior works and raises a FutureWarning
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make it easier to search for when we need to remove the warning:

Suggested change
# Test that the existing behavior works and raises a FutureWarning
# TODO(1.4): Remove check for FutureWarning
# Test that the existing behavior works and raises a FutureWarning

Copy link
Member
@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @OmarManzoor

@@ -646,7 +646,7 @@ def fit(self, X, y, **fit_params):
self.sample_weight_ = fit_params["sample_weight"]
super().fit(X, y, **fit_params)

model = RegressorChain(MySGD())
model = RegressorChain(MySGD().set_fit_request(sample_weight=True))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the common tests are supposed to do that.

estimator=self.base_estimator,
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
)
.warn_on(child="estimator", method="fit", params=None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method (get_metadata_routing) needs to be defined in the child classes because one of them (RegressorChain) already accepts and passes fit_params (hence the warning here correct), but the ClassifierChain doesn't and passing fit_params is added in this PR and therefore shouldn't add this warning.

Copy link
Member
@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM.

@adrinjalali adrinjalali merged commit 1a4f9bd into scikit-learn:sample-props Oct 4, 2022
@OmarManzoor OmarManzoor deleted the classifier_and_regression_chain_routing branch November 25, 2022 07:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
0