KFold splitter falsely claims it supports groups · Issue #22848 · scikit-learn/scikit-learn · GitHub

Closed
adrinjalali opened this issue Mar 15, 2022 · 6 comments · Fixed by #28210
Labels
Documentation Moderate Anything that requires some knowledge of conventions and best practices module:model_selection

Comments

@adrinjalali
Member
adrinjalali commented Mar 15, 2022

@thomasjpfan : (taken from https://github.com/scikit-learn/scikit-learn/pull/22765/files#r825487547)

My mind is a little blown: the current API for KFold.split has a groups parameter that it never uses. Because of the inheritance structure, the API docs for KFold.split say that the group labels are used in KFold.

(This comment is not actionable)

I haven't looked at how to proceed here. It certainly shouldn't claim it supports groups in the docs, and we need to check how to fix the API if we can.
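A quick way to reproduce the behavior in question (a minimal check, assuming scikit-learn is installed): KFold.split accepts a groups argument but silently ignores it, so the folds come out identical either way.

```python
from sklearn.model_selection import KFold

X = [[i] for i in range(10)]
groups = [0] * 5 + [1] * 5  # two clearly separated groups

kf = KFold(n_splits=2)
without = [(tr.tolist(), te.tolist()) for tr, te in kf.split(X)]
with_groups = [(tr.tolist(), te.tolist()) for tr, te in kf.split(X, groups=groups)]

# The folds are identical: groups had no effect, despite what the
# inherited docstring suggests.
assert without == with_groups
```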

@github-actions github-actions bot added the Needs Triage Issue requires triage label Mar 15, 2022
@thomasjpfan thomasjpfan added Documentation and removed Needs Triage Issue requires triage labels Mar 25, 2022
@thomasjpfan
Member

The fix would involve either dynamically rewriting the docstring of the parent class, or overriding split in the subclasses that do not use groups and updating the docstring.
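For context, a minimal plain-Python illustration (no scikit-learn, hypothetical class names) of why the inherited docstring is misleading: when a subclass never defines split, introspection tools show the parent's docstring, which documents a parameter the subclass ignores.

```python
class Base:
    def split(self, X, y=None, groups=None):
        """Split data.

        groups : array-like
            Group labels used to constrain the folds.
        """
        return [X]

class Child(Base):
    # split is inherited; groups is silently ignored, yet
    # help(Child.split) still claims it is used.
    pass

# The child exposes the parent's docstring verbatim.
assert Child.split.__doc__ == Base.split.__doc__
assert "groups" in Child.split.__doc__
```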

@thomasjpfan thomasjpfan added the Moderate Anything that requires some knowledge of conventions and best practices label Mar 25, 2022
@jeremiedbb
Member

The second option seems more like our usual practices :)

@TheisFerre
Contributor

I would be up for taking this issue if it is okay with you.

Just to be sure I understand this issue correctly: we want to override the split method in the KFold cross-validator. The function signature would go from split(self, X, y=None, groups=None) to split(self, X, y=None).

Additionally, I have looked at the unit tests. It seems that some of the cross-validators are tested within the same test, where the groups argument is given to them all. What do you think is the best way of updating the tests? With if-statements? (See the example below, where I have added if isinstance(cv, KFold): in multiple places.)

@ignore_warnings
def test_cross_validator_with_default_params():
    n_samples = 4
    n_unique_groups = 4
    n_splits = 2
    p = 2
    n_shuffle_splits = 10  # (the default value)

    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    X_1d = np.array([1, 2, 3, 4])
    y = np.array([1, 1, 2, 2])
    groups = np.array([1, 2, 3, 4])
    loo = LeaveOneOut()
    lpo = LeavePOut(p)
    kf = KFold(n_splits)
    skf = StratifiedKFold(n_splits)
    lolo = LeaveOneGroupOut()
    lopo = LeavePGroupsOut(p)
    ss = ShuffleSplit(random_state=0)
    ps = PredefinedSplit([1, 1, 2, 2])  # n_splits = no. of unique folds = 2
    sgkf = StratifiedGroupKFold(n_splits)

    loo_repr = "LeaveOneOut()"
    lpo_repr = "LeavePOut(p=2)"
    kf_repr = "KFold(n_splits=2, random_state=None, shuffle=False)"
    skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
    lolo_repr = "LeaveOneGroupOut()"
    lopo_repr = "LeavePGroupsOut(n_groups=2)"
    ss_repr = (
        "ShuffleSplit(n_splits=10, random_state=0, test_size=None, train_size=None)"
    )
    ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"
    sgkf_repr = "StratifiedGroupKFold(n_splits=2, random_state=None, shuffle=False)"

    n_splits_expected = [
        n_samples,
        comb(n_samples, p),
        n_splits,
        n_splits,
        n_unique_groups,
        comb(n_unique_groups, p),
        n_shuffle_splits,
        2,
        n_splits,
    ]

    for i, (cv, cv_repr) in enumerate(
        zip(
            [loo, lpo, kf, skf, lolo, lopo, ss, ps, sgkf],
            [
                loo_repr,
                lpo_repr,
                kf_repr,
                skf_repr,
                lolo_repr,
                lopo_repr,
                ss_repr,
                ps_repr,
                sgkf_repr,
            ],
        )
    ):
        # Test if get_n_splits works correctly
        if isinstance(cv, KFold):
            assert n_splits_expected[i] == cv.get_n_splits(X, y)
        else:
            assert n_splits_expected[i] == cv.get_n_splits(X, y, groups)

        # Test if the cross-validator works as expected even if
        # the data is 1d
        if isinstance(cv, KFold):
            np.testing.assert_equal(
                list(cv.split(X, y)), list(cv.split(X_1d, y))
            )
        else:
            np.testing.assert_equal(
                list(cv.split(X, y, groups)), list(cv.split(X_1d, y, groups))
            )
        # Test that train, test indices returned are integers
        if isinstance(cv, KFold):
            for train, test in cv.split(X, y):
                assert np.asarray(train).dtype.kind == "i"
                assert np.asarray(test).dtype.kind == "i"
        else:
            for train, test in cv.split(X, y, groups):
                assert np.asarray(train).dtype.kind == "i"
                assert np.asarray(test).dtype.kind == "i"

        # Test if the repr works without any errors
        assert cv_repr == repr(cv)

    # ValueError for get_n_splits methods
    msg = "The 'X' parameter should not be None."
    with pytest.raises(ValueError, match=msg):
        loo.get_n_splits(None, y, groups)
    with pytest.raises(ValueError, match=msg):
        lpo.get_n_splits(None, y, groups)

@thomasjpfan
Member
thomasjpfan commented Mar 27, 2022

The function signature should go from split(self, X, y=None, groups=None) to split(self, X, y=None).

No, we cannot change the signature, since it would break backward compatibility. KFold should keep the same signature, but the docstring for groups changes:

class KFold(...):
    def split(self, X, y=None, groups=None):
        """...

        groups : array-like of shape (n_samples,), default=None
            Not used, present for API consistency.
        ...
        """
        return super().split(X, y=y)

@TheisFerre
Contributor

Makes sense.

As far as I can tell, the following cross-validators should have their docstring for the groups argument changed.

  • KFold

  • LeaveOneOut

  • LeavePOut

  • RepeatedKFold

  • RepeatedStratifiedKFold

  • ShuffleSplit

I suggest updating it to the following, as this is how it has been updated in other classes.

class KFold(...):
    def split(self, X, y=None, groups=None):
        """...

        groups : object
            Always ignored, exists for compatibility.
        ...
        """

Also, do we want to update the docstring for both the get_n_splits and split methods?

@BarkleyBG
Contributor

Hi all, I recently came across this issue and wanted to bump it. I would recommend not only changing the docstring, but also throwing an error, or at least a warning, whenever a user provides anything other than groups=None to the non-group splitters (KFold, StratifiedKFold, ShuffleSplit, etc.).

I recommend the additional warning because even when the documentation is clear (as in StratifiedKFold), the split functions will still silently ignore the user's input, which might lead to unintended consequences. This seems like a good place to add if groups is not None: warnings.warn("this is ignored; see documentation") or even to throw an error.
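A minimal sketch of what such a warning could look like, using hypothetical standalone classes (_BaseSplitter, KFoldLike) rather than actual scikit-learn code:

```python
import warnings

class _BaseSplitter:
    """Toy base splitter: yields two trivial half/half folds."""

    def split(self, X, y=None):
        n = len(X)
        half = n // 2
        yield list(range(half, n)), list(range(half))
        yield list(range(half)), list(range(half, n))

class KFoldLike(_BaseSplitter):
    def split(self, X, y=None, groups=None):
        # Keep the signature for API consistency, but warn when a
        # caller passes groups, since it has no effect here.
        if groups is not None:
            warnings.warn(
                "The groups parameter is ignored by this splitter.",
                UserWarning,
            )
        return super().split(X, y)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    folds = list(KFoldLike().split(range(4), groups=[0, 0, 1, 1]))

assert len(caught) == 1 and issubclass(caught[0].category, UserWarning)
```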

I can offer to attempt to push some code to resolve this issue, if you'd like. Would that be helpful?
