TST improve error message on _more_tags and _get_tags deprecation by adrinjalali · Pull Request #29801 · scikit-learn/scikit-learn · GitHub

TST improve error message on _more_tags and _get_tags deprecation #29801


Merged: 8 commits, Sep 12, 2024


adrinjalali
Member

This PR improves the error message so that estimator developers have a better idea of how they can support multiple sklearn versions (ref: #29677 (comment))

Example pytest run:

$ pytest ~/Projects/sklearn/bugs/1.py
==================================================================================================================== test session starts =====================================================================================================================
platform linux -- Python 3.12.3, pytest-8.2.0, pluggy-1.5.0
rootdir: /home/adrin
plugins: xdist-3.5.0, anyio-4.4.0
collected 33 items                                                                                                                                                                                                                                           

Projects/sklearn/bugs/1.py ..................F..............                                                                                                                                                                                           [100%]

========================================================================================================================== FAILURES ==========================================================================================================================
_______________________________________________________________________________________________ test_my_estimator[MyEstimator()-check_estimator_tags_renamed] ________________________________________________________________________________________________

estimator = MyEstimator(), check = functools.partial(<function check_estimator_tags_renamed at 0x74f6290a14e0>, 'MyEstimator')

    @parametrize_with_checks([MyEstimator()])
    def test_my_estimator(estimator, check):
>       check(estimator)

Projects/sklearn/bugs/1.py:30: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

name = 'MyEstimator', estimator_orig = MyEstimator()

    def check_estimator_tags_renamed(name, estimator_orig):
        help = """_more_tags() was removed in 1.6. Please use __sklearn_tags__ instead.
    Use the following template to migrate if you wish to support multiple versions of
    scikit-learn:
    
    import sklearn
    from packaging import version
    from sklearn.base import BaseEstimator
    from sklearn.utils.metaestimators import available_if
    
    def check_version(estimator):
        return version.parse(sklearn.__version__) < version.parse("1.6.dev")
    
    class MyEstimator(BaseEstimator):
        ...
    
        @available_if(check_version)
        def {tags_func}(self):
            tags = dict(key=value, ...)
            return tags
    
        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            tags.some_key = False
            return tags
        ...
    """
    
>       assert not hasattr(estimator_orig, "_more_tags"), help.format(
            tags_func="_more_tags"
        )
E       AssertionError: _more_tags() was removed in 1.6. Please use __sklearn_tags__ instead.
E       Use the following template to migrate if you wish to support multiple versions of
E       scikit-learn: 
E       
E       import sklearn
E       from packaging import version
E       from sklearn.base import BaseEstimator
E       from sklearn.utils.metaestimators import available_if
E       
E       def check_version(estimator):
E           return version.parse(sklearn.__version__) < version.parse("1.6.dev")
E       
E       class MyEstimator(BaseEstimator):
E           ...
E       
E           @available_if(check_version)
E           def _more_tags(self):
E               tags = dict(key=value, ...)
E               return tags
E           
E           def __sklearn_tags__(self):
E               tags = super().__sklearn_tags__()
E               tags.some_key = False
E               return tags
E           ...

Projects/gh/me/scikit-learn/sklearn/utils/estimator_checks.py:3963: AssertionError
====================================================================================================================== warnings summary ======================================================================================================================
Projects/gh/me/scikit-learn/sklearn/utils/_array_api.py:123
  /home/adrin/Projects/gh/me/scikit-learn/sklearn/utils/_array_api.py:123: UserWarning: Some scikit-learn array API features might rely on enabling SciPy's own support for array API to function properly. Please set the SCIPY_ARRAY_API=1 environment variable before importing sklearn or scipy. More details at: https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================================================== short test summary info ===================================================================================================================
FAILED Projects/sklearn/bugs/1.py::test_my_estimator[MyEstimator()-check_estimator_tags_renamed] - AssertionError: _more_tags() was removed in 1.6. Please use __sklearn_tags__ instead.
========================================================================================================== 1 failed, 32 passed, 1 warning in 1.28s ===========================================================================================================

cc @adam2392 @glemaitre @larsoner
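
The template in the error message relies on `available_if` to hide `_more_tags` on newer versions, which is what lets the `hasattr` assertion in `check_estimator_tags_renamed` pass. A minimal sketch of that mechanism follows. It uses a simplified stand-in descriptor rather than scikit-learn's actual `available_if`; the names `conditional_method`, `available_when`, and the `old_sklearn` flag are hypothetical and only there for illustration.

```python
class conditional_method:
    """Simplified, hypothetical stand-in for a conditional-method descriptor."""

    def __init__(self, check, fn):
        self.check = check
        self.fn = fn

    def __get__(self, obj, owner=None):
        # Raising AttributeError on access is what makes hasattr() return False.
        if obj is not None and not self.check(obj):
            raise AttributeError(f"{self.fn.__name__} is not available")
        return self.fn.__get__(obj, owner)


def available_when(check):
    """Decorator factory that wraps a method in conditional_method."""
    return lambda fn: conditional_method(check, fn)


class MyEstimator:
    def __init__(self, old_sklearn):
        # Assumption: this flag stands in for the sklearn.__version__ comparison.
        self.old_sklearn = old_sklearn

    @available_when(lambda est: est.old_sklearn)
    def _more_tags(self):
        return {"allow_nan": True}


print(hasattr(MyEstimator(old_sklearn=True), "_more_tags"))   # True
print(hasattr(MyEstimator(old_sklearn=False), "_more_tags"))  # False
```

With the flag off, the attribute lookup raises `AttributeError`, so a check written as `assert not hasattr(estimator, "_more_tags")` would not fire.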

github-actions bot commented Sep 7, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 6ca460c.

@ogrisel ogrisel added the Developer API Third party developer API related label Sep 9, 2024
from sklearn.utils.metaestimators import available_if

def check_version(estimator):
    return version.parse(sklearn.__version__) < version.parse("1.6.dev")
Member

I think we can take the base version to avoid having to deal with the .dev suffix.

Member

I think it could be nice to have that in the user guide as well for the transition.

Member Author

I'm reluctant to put this in the docs since it's about private API. In practice, I think users would see this in the error message rather than in the docs.

As for the version, the issue is that 1.5.2 < 1.6.dev < 1.6 and the change really happens in 1.6.dev. If we use any other version, we might need to change it. I've also seen code in other repos that references an unreleased version of Python.
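
The ordering mentioned here, and the base-version alternative suggested earlier in this thread, can be checked with `packaging` directly. This is a small sketch that uses literal version strings in place of `sklearn.__version__`; the helper names are hypothetical.

```python
from packaging.version import Version


def is_pre_1_6_dev_cutoff(v):
    """Compare against the development pre-release, as in the template."""
    return Version(v) < Version("1.6.dev")


def is_pre_1_6_base(v):
    """Strip pre-release/dev suffixes first, then compare against 1.6."""
    return Version(Version(v).base_version) < Version("1.6")


# PEP 440 orders dev releases before the final release of the same version.
print(Version("1.5.2") < Version("1.6.dev") < Version("1.6"))  # True

for v in ["1.5.2", "1.6.dev0", "1.6.0"]:
    print(v, is_pre_1_6_dev_cutoff(v), is_pre_1_6_base(v))
```

For these three inputs the two helpers agree: only 1.5.2 counts as older than the cutoff, and a 1.6.dev0 build is treated the same as 1.6.0.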

Member
@thomasjpfan thomasjpfan Sep 11, 2024

Is it safe just to include _more_tags without checking the scikit-learn version? Specifically, can scikit-learn 1.6 ignore _more_tags if __sklearn_tags__ is defined?

Member

I would guess the issue is that users who define _more_tags (but not __sklearn_tags__) would find that things do not work as expected once they transition to v1.6?

Member Author

I think it's fine to have both methods. So, if the user defines both _more_tags and __sklearn_tags__ they most probably know what they're doing.

I've changed the PR to reflect that.
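
One way to read "fine to have both methods" is that the check fails only when a legacy method exists without its replacement. Below is a rough sketch of such a relaxed check, using plain stand-in classes. It illustrates the idea and is not the code merged in this PR; `check_tags_renamed_sketch`, `LegacyOnly`, and `BothAPIs` are hypothetical names.

```python
def check_tags_renamed_sketch(estimator):
    """Fail only if a legacy tag method exists without __sklearn_tags__."""
    if hasattr(estimator, "__sklearn_tags__"):
        # Old and new APIs may coexist to support several versions.
        return
    for legacy in ("_more_tags", "_get_tags"):
        assert not hasattr(estimator, legacy), (
            f"{legacy}() was removed in 1.6. "
            "Please use __sklearn_tags__ instead."
        )


class LegacyOnly:
    def _more_tags(self):
        return {"allow_nan": True}


class BothAPIs(LegacyOnly):
    def __sklearn_tags__(self):
        # Placeholder return value; a real estimator would build on super().
        return {"allow_nan": True}


check_tags_renamed_sketch(BothAPIs())  # passes silently

try:
    check_tags_renamed_sketch(LegacyOnly())
except AssertionError as exc:
    print(exc)
```

Under this reading, an estimator that defines only `_more_tags` still gets the migration message, while one that defines both is left alone.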

Member

If they can safely define both, then can we remove the need for available_if? The story becomes "If you want to support multiple scikit-learn versions, define both."

Member Author

Done, simplified the message.

@glemaitre glemaitre self-requested a review September 12, 2024 13:06
Member
@glemaitre glemaitre left a comment

LGTM also. Thanks @adrinjalali

@glemaitre glemaitre merged commit d4ab9ed into scikit-learn:main Sep 12, 2024
30 checks passed
@adrinjalali adrinjalali deleted the tags/message branch September 12, 2024 13:17