Add metadata routing params support in the predict method of `BaggingClassifier/Regressor` · Issue #30808 · scikit-learn/scikit-learn

Closed
aperezlebel opened this issue Feb 10, 2025 · 7 comments · Fixed by #30833
Labels
Metadata Routing (all issues related to metadata routing, slep006, sample props) · New Feature

Comments

@aperezlebel
Contributor

Describe the workflow you want to enable

Hello! I'm trying to use metadata routing with BaggingClassifier and BaggingRegressor; however, routing is implemented for the `fit` method but not for `predict`. I am wondering whether there is a particular reason for not supporting it in `predict`, or whether this is a feature that could be added. This would enable situations like the following, which currently raises an error:

import numpy as np
import sklearn
from sklearn import ensemble
from sklearn.base import BaseEstimator

sklearn.set_config(enable_metadata_routing=True)


class CustomEstimator(BaseEstimator):
    def fit(self, X, y, foo):
        return self

    def predict(self, X, bar):
        return np.zeros(X.shape[0])


estimator = CustomEstimator()
estimator.set_fit_request(foo=True)
estimator.set_predict_request(bar=True)
model = ensemble.BaggingRegressor(estimator)

n, p = 10, 2
rng = np.random.default_rng(0)
x = rng.random((n, p))
y = rng.integers(0, 2, n)

model.fit(x, y, foo=True)
model.predict(x, bar=True)  # TypeError: BaggingRegressor.predict() got an unexpected keyword argument 'bar'
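Until `predict` supports routing, one possible interim workaround for `BaggingRegressor` is to call each fitted sub-estimator directly, forwarding the metadata by hand, and average the results. This sketch assumes that `BaggingRegressor.predict` averages `estimator.predict(X[:, features])` over the fitted `estimators_` and `estimators_features_` (input validation aside); the helper `bagging_predict_with_metadata` is hypothetical, and `BaggingClassifier`, which aggregates votes or probabilities, would need a different helper.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.ensemble import BaggingRegressor


class CustomEstimator(RegressorMixin, BaseEstimator):
    """Toy estimator whose predict takes extra metadata (`bar`), as in the issue."""

    def fit(self, X, y):
        return self

    def predict(self, X, bar=1.0):
        # The prediction depends on the metadata, so forwarding it is observable.
        return X[:, 0] * bar


def bagging_predict_with_metadata(model, X, **predict_params):
    """Hypothetical helper: average sub-estimator predictions, forwarding params.

    Assumed to mirror what BaggingRegressor.predict computes, except that
    predict_params are passed through to every fitted sub-estimator.
    """
    X = np.asarray(X)
    preds = [
        # Each sub-estimator only sees the feature subset it was trained on.
        est.predict(X[:, features], **predict_params)
        for est, features in zip(model.estimators_, model.estimators_features_)
    ]
    return np.mean(preds, axis=0)
```

With the default `bar=1.0`, the helper should agree with `model.predict(X)`; passing `bar=2.0` then doubles every prediction.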

Describe your proposed solution

Similar to the fit method, something like:

if _routing_enabled():
    routed_params = process_routing(self, "predict", **predict_params)

However, I don't have enough understanding of the metadata routing implementation to know exactly what should be done.

Describe alternatives you've considered, if relevant

No response

Additional context

I looked through the history of PRs and issues for a discussion of this point, but could not find one, including in the PR that introduced metadata routing to these estimators (#28432).

@aperezlebel added the Needs Triage and New Feature labels on Feb 10, 2025
@adrinjalali
Member

In the first round of implementation, we focused on `fit`; happy to have a PR implementing this for the `predict` method.

@adrinjalali added the Metadata Routing label and removed the Needs Triage label on Feb 11, 2025
@aperezlebel
Contributor Author

Thank you for the feedback!

@adrinjalali
Member

cc @StefanieSenger in case you fancy taking on this one.

@StefanieSenger
Contributor

Yes, I'm happy to take care of it.

@StefanieSenger
Contributor

Out of curiosity, what is your use case that requires routing metadata to `predict` methods, @aperezlebel?

@aperezlebel
Contributor Author

I'm working with an estimator that uses feature names in addition to the feature values. When I wrap it inside a BaggingClassifier or BaggingRegressor, I can't pass the feature names to the `predict` method, even when using a pandas DataFrame, because the bagging estimator converts it to a NumPy array before passing it to the base estimator.

Thank you @StefanieSenger for taking care of this PR!

@StefanieSenger
Contributor

Thank you, @aperezlebel. I feel that what you describe is something I would expect scikit-learn to support without metadata routing, although I didn't check exactly how to do that. Anyway, it seems routing to `predict` will be in version 1.7.

Projects
Status: Done
3 participants