API Revamp estimator tags by adrinjalali · Pull Request #29677 · scikit-learn/scikit-learn · GitHub

API Revamp estimator tags #29677

Merged on Sep 4, 2024 (63 commits)

@adrinjalali (Member) commented Aug 15, 2024

Closes #22606
Closes #20804

This PR revamps estimator tags and stores them in dataclasses. It is based on #22606.

High level changes from this PR:

(from #22606):

  • replace MRO mechanism with inheritance
  • remove _get_tags and _more_tags and introduce __sklearn_tags__

this PR:

  • dataclasses are introduced to store the tags, which are grouped by scope into a few nested dataclasses. This also helps users with auto-complete.
  • stateless is removed and now we only use requires_fit. The two were redundant.
  • only_binary is now replaced with multi_class
  • multioutput_only is removed and now we have multi_output and single_output
  • get_tags, default_tags, and Tags are put into public API
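
As a rough illustration of the two ideas above (tags stored in scoped dataclasses, and each class refining its parent's tags through `super()` rather than returning a dict from `_more_tags`), here is a minimal plain-Python sketch. It does not import scikit-learn. The class names `TargetTagsSketch`, `TagsSketch`, `BaseSketch`, `MultiOutputMixinSketch` and `MyEstimatorSketch`, and the way the fields are grouped, are assumptions made for the example and are not the classes added by this PR. Only the tag names (`requires_fit`, `no_validation`, `multi_output`, `single_output`) and the `__sklearn_tags__` method name come from this conversation.

```python
from dataclasses import dataclass, field


@dataclass
class TargetTagsSketch:
    """Hypothetical group of tags that describe the target."""

    multi_output: bool = False
    single_output: bool = True


@dataclass
class TagsSketch:
    """Hypothetical top-level container; nested groups scope related tags."""

    requires_fit: bool = True
    no_validation: bool = False
    # default_factory gives every Tags object its own nested instance.
    target_tags: TargetTagsSketch = field(default_factory=TargetTagsSketch)


class BaseSketch:
    def __sklearn_tags__(self):
        # The root of the hierarchy creates the default tags.
        return TagsSketch()


class MultiOutputMixinSketch(BaseSketch):
    def __sklearn_tags__(self):
        # Refine what the parent returned instead of returning a partial dict.
        tags = super().__sklearn_tags__()
        tags.target_tags.multi_output = True
        return tags


class MyEstimatorSketch(MultiOutputMixinSketch):
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.no_validation = True
        return tags


tags = MyEstimatorSketch().__sklearn_tags__()
print(tags)
```

Because the tags are attributes rather than dict keys, a misspelled tag name shows up in an editor or type checker, which is the auto-complete benefit mentioned in the list.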

github-actions bot commented Aug 15, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 88af896.

@glemaitre glemaitre self-requested a review September 3, 2024 14:44
@glemaitre (Member) left a comment

A bunch of changes, mainly to make the documentation style consistent.

@adrinjalali (Member, Author) commented

CI is green, @glemaitre.

@glemaitre (Member) left a comment

So LGTM.

@glemaitre glemaitre merged commit e04142c into scikit-learn:main Sep 4, 2024
30 checks passed
@adrinjalali adrinjalali deleted the estimator-tags branch September 5, 2024 08:38
@larsoner (Contributor) commented Sep 6, 2024

I had a quick look at the docs and maybe I missed it, but as a consumer of sklearn that subclasses BaseEstimator I'm not sure how to adapt my code. For example, I currently have classes that have stuff like:

    def _more_tags(self):
        return {"no_validation": True}

How do I adapt these classes in a way that is backward compatible with previous sklearn versions? I can't just leave _more_tags in there to be picked up by older versions of sklearn because then I fail validation:

mne/decoding/tests/test_search_light.py:373: in test_sklearn_compliance
    check(est)
../virtualenvs/base/lib/python3.12/site-packages/sklearn/utils/estimator_checks.py:3893: in check_estimator_tags_renamed
    assert not hasattr(estimator_orig, "_more_tags"), (
E   AssertionError: ('_more_tags() was removed in 1.6. Please use __sklearn_tags__ instead.',)

But I can't remove it because then it won't be backward compatible.

One solution would be to add an opt-in to have the validator ignore this attribute being present, or maybe change it to ensure that if _more_tags is present then __sklearn_tags__ is also present.

Or is there a simpler way for me to adjust my code?
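
To make the second suggestion concrete, a relaxed check along those lines might look roughly like the sketch below. This is only an illustration of the proposal, not the check that ships in scikit-learn. The function name `check_tags_renamed_relaxed` and the two toy classes are hypothetical, and the placeholder return values stand in for real tags.

```python
def check_tags_renamed_relaxed(estimator):
    """Hypothetical check: tolerate _more_tags if the same class also
    defines __sklearn_tags__."""
    for klass in type(estimator).__mro__:
        # Look only at what each class defines itself, not what it inherits.
        defined = vars(klass)
        if "_more_tags" in defined and "__sklearn_tags__" not in defined:
            raise AssertionError(
                f"{klass.__name__} defines _more_tags() without "
                "__sklearn_tags__; please add __sklearn_tags__."
            )


class LegacyOnly:
    def _more_tags(self):
        return {"no_validation": True}


class LegacyAndNew:
    def _more_tags(self):
        return {"no_validation": True}

    def __sklearn_tags__(self):
        return {"no_validation": True}  # placeholder for the sketch


check_tags_renamed_relaxed(LegacyAndNew())  # tolerated: both are defined

try:
    check_tags_renamed_relaxed(LegacyOnly())
except AssertionError:
    legacy_only_rejected = True
else:
    legacy_only_rejected = False
```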

@adrinjalali (Member, Author) commented

@larsoner you can leave _more_tags there behind an @available_if decorator whose check compares against the scikit-learn version.

Something like this:

import numpy as np
import sklearn
from packaging import version
from sklearn.base import BaseEstimator
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_is_fitted, validate_data


def check_version(estimator):
    # True only on scikit-learn versions that still use _more_tags.
    return version.parse(sklearn.__version__) < version.parse("1.6.dev")


class MyEstimator(BaseEstimator):
    # On scikit-learn >= 1.6 this attribute is hidden, so the
    # check_estimator_tags_renamed check no longer sees it.
    @available_if(check_version)
    def _more_tags(self):
        return {"_skip_test": False}

    def fit(self, X, y=None):
        validate_data(self, X, y)
        return self

    def predict(self, X):
        check_is_fitted(self)
        # Keep the validated array so that list inputs also have a .shape.
        X = validate_data(self, X, reset=False)
        return np.zeros(X.shape[0], dtype=int)


@parametrize_with_checks([MyEstimator()])
def test_my_estimator(estimator, check):
    check(estimator)
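
For readers wondering why the decorator satisfies the `assert not hasattr(estimator_orig, "_more_tags")` check from the traceback: a conditional method can be built as a descriptor that raises `AttributeError` when its check fails, and `hasattr` treats that as "attribute missing". The following is a simplified plain-Python stand-in for that idea, not scikit-learn's implementation. The names `AvailableIfSketch`, `available_if_sketch`, `Gated` and the `legacy` flag (which replaces the version comparison) are hypothetical.

```python
class AvailableIfSketch:
    """Descriptor that exposes a method only when check(obj) is truthy."""

    def __init__(self, fn, check):
        self.fn = fn
        self.check = check

    def __get__(self, obj, owner=None):
        if obj is None:
            # Accessed on the class itself: return the plain function.
            return self.fn
        if not self.check(obj):
            # Raising AttributeError is what makes hasattr() return False.
            raise AttributeError(f"{self.fn.__name__} is not available")
        # Bind the function to the instance, like a normal method.
        return self.fn.__get__(obj, owner)


def available_if_sketch(check):
    def decorator(fn):
        return AvailableIfSketch(fn, check)

    return decorator


class Gated:
    def __init__(self, legacy):
        # Stand-in for "the installed scikit-learn is older than 1.6".
        self.legacy = legacy

    @available_if_sketch(lambda est: est.legacy)
    def _more_tags(self):
        return {"no_validation": True}


old_style = Gated(legacy=True)
new_style = Gated(legacy=False)
print(hasattr(old_style, "_more_tags"), hasattr(new_style, "_more_tags"))
```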

Labels: Developer API (Third party developer API related)
Linked issue that this pull request may close: Revisting the tags interface
7 participants