FEA Add metadata routing through predict methods of BaggingClassifier and BaggingRegressor by StefanieSenger · Pull Request #30833 · scikit-learn/scikit-learn · GitHub

FEA Add metadata routing through predict methods of BaggingClassifier and BaggingRegressor #30833

Merged
merged 13 commits into from
Mar 18, 2025

Conversation

StefanieSenger
Contributor
@StefanieSenger StefanieSenger commented Feb 14, 2025

Reference Issues/PRs

closes #30808
towards #22893

What does this implement/fix? Explain your changes.

Add metadata routing functionality to the predict, predict_proba, predict_log_proba and decision_function methods of BaggingClassifier and BaggingRegressor.

CC @adrinjalali @OmarManzoor
Would you like to have a look?

github-actions bot commented Feb 14, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: f235bcd. Link to the linter CI: here

Comment on lines -301 to 304
- y_proba = np.empty(shape=(len(X), 2))
- y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0])
- y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0])
+ y_proba = np.empty(shape=(len(X), len(self.classes_)), dtype=np.float32)
+ # each row sums up to 1.0:
+ y_proba[:] = np.random.dirichlet(alpha=np.ones(len(self.classes_)), size=len(X))
  return y_proba
Contributor Author
@StefanieSenger StefanieSenger Feb 14, 2025

It is necessary to make sure predict_proba and predict_log_proba return a column per class for all the test classes (ConsumingClassifier and NonConsumingClassifier) to avoid shape mismatch while testing.
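
The shape requirement can be illustrated with NumPy alone. The following is a minimal sketch, not the test helper from the PR: `fake_predict_proba` and its arguments are hypothetical names, and a seeded `Generator` is used in place of the global `np.random.dirichlet` call from the diff so that the result is reproducible.

```python
import numpy as np


def fake_predict_proba(n_samples, classes, seed=0):
    """Return random class probabilities with one column per class.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    rng = np.random.default_rng(seed)
    # Each Dirichlet draw is a probability vector, so every row sums to 1.0.
    proba = rng.dirichlet(alpha=np.ones(len(classes)), size=n_samples)
    return proba.astype(np.float32)


classes = np.array(["a", "b", "c"])
y_proba = fake_predict_proba(n_samples=5, classes=classes)
print(y_proba.shape)  # (5, 3)
```

With a fixed two-column array, any test using more than two classes would hit the shape mismatch mentioned above; sizing the array from the class list avoids that.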

@StefanieSenger StefanieSenger changed the title FEA Add metadata routing trough predict methods of Bagging* FEA Add metadata routing through predict methods of Bagging* Feb 14, 2025
@StefanieSenger StefanieSenger changed the title FEA Add metadata routing through predict methods of Bagging* FEA Add metadata routing through predict methods of BaggingClassifier and BaggingRegressor Feb 14, 2025
@StefanieSenger StefanieSenger marked this pull request as ready for review February 15, 2025 14:38
Contributor
@OmarManzoor OmarManzoor left a comment

A few comments otherwise looks good

@OmarManzoor OmarManzoor added the Waiting for Second Reviewer First reviewer is done, need a second one! label Feb 17, 2025
Contributor Author
@StefanieSenger StefanieSenger left a comment

Thank you for your review, @OmarManzoor!
I have added a comment on the conditions for the router.

Contributor
@OmarManzoor OmarManzoor left a comment

LGTM. Thanks @StefanieSenger

Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
Member
@adrinjalali adrinjalali left a comment

Overall looks good, but a few things need fixing. This is not a complete review. Please ping me once you're ready for another round.

Contributor Author
@StefanieSenger StefanieSenger left a comment

Thank you for reviewing, @adrinjalali!

I have changed the naming of the routed kwargs into **params, added documentation on the routing depending on which methods are available in the sub-estimators, moved the changelog entry into the correct section, re-defined the router depending on whether predict_log_proba is available and added corresponding test cases.

This PR is ready for another round of reviewing.

@StefanieSenger StefanieSenger added Metadata Routing all issues related to metadata routing, slep006, sample props and removed Waiting for Second Reviewer First reviewer is done, need a second one! labels Feb 18, 2025
# `sample_weight` is passed to the respective methods dynamically at
# runtime:
if hasattr(self._get_estimator(), "predict_proba"):
    method_mapping.add(caller="predict_log_proba", callee="predict_proba")
Member

seems like there's no test case for this one

Contributor Author

That is because we need to test it with a classifier that does not have predict_log_proba, but has predict_proba.

I have added an "outsourced" test just as I did for the case when the sub-classifier doesn't have predict_proba.

    routed_params = process_routing(self, "predict_proba", **params)
else:
    routed_params = Bunch()
    routed_params.estimator = Bunch(predict_proba=params)
Member

we don't really support routing metadata around if metadata routing is disabled anyway (the _raise_for_params above). So here the Bunch can/should be empty.
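
A minimal sketch of the shape the reviewer is asking for when routing is disabled, assuming the nested `routed_params.estimator.<method>` layout seen in the diff. `empty_routed_params` and the stand-in `predict` function are hypothetical names; only `Bunch` comes from scikit-learn.

```python
from sklearn.utils import Bunch


def empty_routed_params(method):
    """Build routed params that forward nothing to the sub-estimator.

    Hypothetical helper for illustration only.
    """
    routed_params = Bunch()
    # The inner Bunch is empty, so unpacking it with ** yields no
    # keyword arguments at all.
    routed_params.estimator = Bunch(**{method: Bunch()})
    return routed_params


def predict(X, **kwargs):
    # Stand-in for a sub-estimator method; reports what it received.
    return sorted(kwargs)


routed = empty_routed_params("predict")
print(predict([[0, 1]], **routed.estimator.predict))  # []
```

Keeping the nested layout means the call site can unpack `routed_params.estimator.<method>` the same way in both branches, while nothing is forwarded when routing is off.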

**params : dict
    Parameters routed to the `predict_log_proba`, the `predict_proba` or the
    `proba` method of the sub-estimators via the metadata routing API. The
    routing is tried in the before mentioned order depending on whether this
Member

Suggested change
- routing is tried in the before mentioned order depending on whether this
+ routing is tried in the mentioned order depending on whether this

    routed_params = process_routing(self, "predict_log_proba", **params)
else:
    routed_params = Bunch()
    routed_params.estimator = Bunch(predict_log_proba=params)
Member

same here with not routing params

    routed_params = process_routing(self, "decision_function", **params)
else:
    routed_params = Bunch()
    routed_params.estimator = Bunch(decision_function=params)
Member

and here

    routed_params = process_routing(self, "predict", **params)
else:
    routed_params = Bunch()
    routed_params.estimator = Bunch(predict=params)

and here

Contributor Author
@StefanieSenger StefanieSenger left a comment

Thanks for your review, @adrinjalali.

I have removed all the passed params from the bunches without metadata routing and added test cases and new test classes for dynamic method selection.

@@ -0,0 +1,4 @@
- :class:`ensemble.BaggingClassifier` and :class:`ensemble.BaggingRegressor` now support
metadata routing through their `predict`, `predict_proba`, `predict_log_proba` and
`decision_function` methods and pass `**predict_params` to the underlying estimators.
Member

Suggested change
- `decision_function` methods and pass `**predict_params` to the underlying estimators.
+ `decision_function` methods and pass `**params` to the underlying estimators.

@@ -1279,6 +1403,14 @@ def predict(self, X):
reset=False,
)

_raise_for_params(params, self, "predict")
Member

I'd put all these _raise_for_params calls as the first thing in these methods, so that the user knows as soon as possible that they need to change their script, instead of fitting and fixing their data, and then suddenly getting this error.

Contributor Author

Yes, that makes sense, especially if fitting takes a while. I have put all the _raise_for_params() first.
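
The ordering argument can be sketched without scikit-learn. Everything below is hypothetical: `ROUTING_ENABLED`, `raise_for_params`, `expensive_validation` and `predict` are stand-ins, not the library's helpers. The point is only that the cheap parameter check runs before any costly work.

```python
ROUTING_ENABLED = False  # stand-in for the library's global routing flag
calls = []  # records which steps actually ran


def raise_for_params(params, method):
    """Reject extra keyword arguments when routing is disabled."""
    if params and not ROUTING_ENABLED:
        raise ValueError(
            f"Passing extra keyword arguments to {method} requires "
            "metadata routing to be enabled."
        )


def expensive_validation(X):
    # Stand-in for input checking that may be slow on large data.
    calls.append("validate")
    return list(X)


def predict(X, **params):
    # Cheap check first, so the caller fails before any heavy lifting.
    raise_for_params(params, "predict")
    X = expensive_validation(X)
    return [0 for _ in X]


try:
    predict([[1, 2]], sample_weight=[1.0])
except ValueError as exc:
    print("rejected early:", exc)

print(calls)  # [] because validation never started
```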

)
bagging = BaggingClassifier(estimator=estimator)
bagging.fit(X, y)
bagging.predict(X=np.array([[1, 1], [1, 3], [0, 2]]))
Member

I think the reason we still have a codecov warning is that here you're not calling predict_log_proba and/or predict_proba. I think we need to test all the 4 methods, and also need to pass metadata here to actually record something and activate the routing mechanism.

Contributor Author

Fixed in commit a3acb0. Thanks for your help with it!

Member
@adrinjalali adrinjalali left a comment

Outstanding!

@adrinjalali adrinjalali merged commit 0372d5e into scikit-learn:main Mar 18, 2025
33 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Metadata routing Mar 18, 2025
@StefanieSenger StefanieSenger deleted the metadata_routing_bagging branch March 19, 2025 08:07
Labels
Metadata Routing all issues related to metadata routing, slep006, sample props module:ensemble
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

Add metadata routing params support in the predict method of BaggingClassifier/Regressor
3 participants