API Revamp estimator tags #29677
A bunch of changes, mainly for consistency with the documentation style.
CI is green, @glemaitre.
So LGTM.
I had a quick look at the docs and maybe I missed it, but as a consumer of sklearn that subclasses BaseEstimator I'm not sure how to adapt my code. For example, I currently have classes that have stuff like:
How do I adapt these classes in a way that is backward compatible with previous sklearn versions? I can't just leave
But I can't remove it because then it won't be backward compatible. One solution would be to add an opt-in to have the validator ignore this attribute being present, or maybe change it to ensure that if
Or is there a simpler way for me to adjust my code?
@larsoner you can leave Something like this:

    import numpy as np
    import sklearn
    from packaging import version
    from sklearn.base import BaseEstimator
    from sklearn.utils.estimator_checks import parametrize_with_checks
    from sklearn.utils.metaestimators import available_if
    from sklearn.utils.validation import check_is_fitted, validate_data


    def check_version(estimator):
        return version.parse(sklearn.__version__) < version.parse("1.6.dev")


    class MyEstimator(BaseEstimator):
        @available_if(check_version)
        def _more_tags(self):
            return {"_skip_test": False}

        def fit(self, X, y=None):
            validate_data(self, X, y)
            return self

        def predict(self, X):
            check_is_fitted(self)
            validate_data(self, X, reset=False)
            return np.zeros(X.shape[0], dtype=int)


    @parametrize_with_checks([MyEstimator()])
    def test_my_estimator(estimator, check):
        check(estimator)
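The version gate in the snippet above works because a method wrapped with `available_if` is invisible to `hasattr` whenever its check returns `False`. The following is a minimal, standard-library-only sketch of that mechanism, not scikit-learn's implementation. The names `gated`, `GatedMethod`, `OldStyle`, `NewStyle`, and `library_version` are hypothetical and exist only for illustration; the version is hard-coded as a class attribute instead of being read from `sklearn.__version__`.

```python
import types


class GatedMethod:
    """Descriptor that exposes a method only when `check(obj)` is true.

    Hypothetical, simplified stand-in for the behaviour of
    sklearn.utils.metaestimators.available_if.
    """

    def __init__(self, fn, check):
        self.fn = fn
        self.check = check

    def __get__(self, obj, owner=None):
        if obj is None:
            # Accessed on the class itself: return the plain function.
            return self.fn
        if not self.check(obj):
            # Raising AttributeError is what makes hasattr() return False.
            raise AttributeError(
                f"{type(obj).__name__!r} object has no attribute "
                f"{self.fn.__name__!r}"
            )
        return types.MethodType(self.fn, obj)


def gated(check):
    """Decorator factory mirroring the `@available_if(check)` usage above."""

    def decorator(fn):
        return GatedMethod(fn, check)

    return decorator


class OldStyle:
    # Assumption: stands in for an installed library version below 1.6.
    library_version = (1, 5)

    @gated(lambda est: est.library_version < (1, 6))
    def _more_tags(self):
        return {"_skip_test": False}


class NewStyle(OldStyle):
    # Assumption: stands in for an installed library version of 1.6 or later.
    library_version = (1, 6)


print(hasattr(OldStyle(), "_more_tags"))  # True
print(hasattr(NewStyle(), "_more_tags"))  # False
```

With this pattern the legacy method stays in the source for older releases, while any check that looks for the attribute on an instance no longer finds it on newer ones.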
Closes #22606
Closes #20804
This PR revamps estimator tags, puts them in dataclasses, and is based on #22606
High-level changes from this PR:
(from #22606):
- `_get_tags` and `_more_tags` and introduce `__sklearn_tags__`
this PR:
- `stateless` is removed and now we only use `requires_fit`. The two were redundant.
- `only_binary` is now replaced with `multi_class`
- `multioutput_only` is removed and now we have `multi_output` and `single_output`
- `get_tags`, `default_tags`, and `Tags` are put into public API
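As a rough illustration of the changes listed above, here is a standard-library sketch of tags held in a dataclass and customised through a `__sklearn_tags__` method. It is an assumption-laden toy, not the actual `Tags` class: `SketchTags`, `SketchBase`, `BinaryOnlyClassifier`, and `convert_old_tags` are hypothetical names, the flat field layout and the default values are guesses, and the old-to-new mapping is inferred from the wording of the list rather than taken from the implementation.

```python
from dataclasses import dataclass


@dataclass
class SketchTags:
    """Hypothetical flat tags container; the real `Tags` layout may differ."""

    requires_fit: bool = True    # assumed default
    multi_class: bool = True     # assumed default
    multi_output: bool = False   # assumed default
    single_output: bool = True   # assumed default


class SketchBase:
    """Toy base class that returns default tags."""

    def __sklearn_tags__(self):
        return SketchTags()


class BinaryOnlyClassifier(SketchBase):
    """Toy estimator that overrides a single tag on top of the defaults."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.multi_class = False  # plays the role of the old `only_binary`
        return tags


def convert_old_tags(old):
    """Translate an old-style tag dict into SketchTags (inferred mapping)."""
    tags = SketchTags()
    if old.get("stateless", False):
        # `stateless` was redundant with `requires_fit`.
        tags.requires_fit = False
    if old.get("only_binary", False):
        tags.multi_class = False
    if old.get("multioutput_only", False):
        tags.multi_output = True
        tags.single_output = False
    return tags


print(BinaryOnlyClassifier().__sklearn_tags__())
print(convert_old_tags({"multioutput_only": True}))
```

Attribute access on a dataclass means a misspelled tag name is easier to spot in review than a misspelled dict key, which is one plausible motivation for the move away from dicts.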