KFold splitter falsely claims it supports groups #22848
Comments
The fix would involve either dynamically rewriting the docstring of the parent class or overriding
The second option seems more like our usual practices :)
I would be up for taking this issue if it is okay with you.
Just to be sure I understand this issue correctly: we want to override the split method in the KFold cross-validator. The function signature should go from
Additionally, I have looked at the unit tests. It seems that some of the cross-validators are tested within the same test, where the
@ignore_warnings
def test_cross_validator_with_default_params():
    n_samples = 4
    n_unique_groups = 4
    n_splits = 2
    p = 2
    n_shuffle_splits = 10  # (the default value)
    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    X_1d = np.array([1, 2, 3, 4])
    y = np.array([1, 1, 2, 2])
    groups = np.array([1, 2, 3, 4])
    loo = LeaveOneOut()
    lpo = LeavePOut(p)
    kf = KFold(n_splits)
    skf = StratifiedKFold(n_splits)
    lolo = LeaveOneGroupOut()
    lopo = LeavePGroupsOut(p)
    ss = ShuffleSplit(random_state=0)
    ps = PredefinedSplit([1, 1, 2, 2])  # n_splits = number of unique folds = 2
    sgkf = StratifiedGroupKFold(n_splits)
    loo_repr = "LeaveOneOut()"
    lpo_repr = "LeavePOut(p=2)"
    kf_repr = "KFold(n_splits=2, random_state=None, shuffle=False)"
    skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
    lolo_repr = "LeaveOneGroupOut()"
    lopo_repr = "LeavePGroupsOut(n_groups=2)"
    ss_repr = (
        "ShuffleSplit(n_splits=10, random_state=0, test_size=None, train_size=None)"
    )
    ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"
    sgkf_repr = "StratifiedGroupKFold(n_splits=2, random_state=None, shuffle=False)"
    n_splits_expected = [
        n_samples,
        comb(n_samples, p),
        n_splits,
        n_splits,
        n_unique_groups,
        comb(n_unique_groups, p),
        n_shuffle_splits,
        2,
        n_splits,
    ]
    for i, (cv, cv_repr) in enumerate(
        zip(
            [loo, lpo, kf, skf, lolo, lopo, ss, ps, sgkf],
            [
                loo_repr,
                lpo_repr,
                kf_repr,
                skf_repr,
                lolo_repr,
                lopo_repr,
                ss_repr,
                ps_repr,
                sgkf_repr,
            ],
        )
    ):
        # Test if get_n_splits works correctly
        if isinstance(cv, KFold):
            assert n_splits_expected[i] == cv.get_n_splits(X, y)
        else:
            assert n_splits_expected[i] == cv.get_n_splits(X, y, groups)
        # Test if the cross-validator works as expected even if
        # the data is 1d
        if isinstance(cv, KFold):
            np.testing.assert_equal(
                list(cv.split(X, y)), list(cv.split(X_1d, y))
            )
        else:
            np.testing.assert_equal(
                list(cv.split(X, y, groups)), list(cv.split(X_1d, y, groups))
            )
        # Test that train, test indices returned are integers
        if isinstance(cv, KFold):
            for train, test in cv.split(X, y):
                assert np.asarray(train).dtype.kind == "i"
                assert np.asarray(test).dtype.kind == "i"
        else:
            for train, test in cv.split(X, y, groups):
                assert np.asarray(train).dtype.kind == "i"
                assert np.asarray(test).dtype.kind == "i"
        # Test if the repr works without any errors
        assert cv_repr == repr(cv)
    # ValueError for get_n_splits methods
    msg = "The 'X' parameter should not be None."
    with pytest.raises(ValueError, match=msg):
        loo.get_n_splits(None, y, groups)
    with pytest.raises(ValueError, match=msg):
        lpo.get_n_splits(None, y, groups)
No, we cannot change the signature since it would break backward compatibility.
class KFold(...):
    def split(self, X, y=None, groups=None):
        """...
        groups : array-like of shape (n_samples,), default=None
            Not used, present for API consistency.
        ...
        """
        return super().split(X, y=y)
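To see what overriding only for the docstring looks like in practice, here is a minimal sketch. It does not use scikit-learn: `BaseSplitter` and `PlainSplitter` are hypothetical stand-ins for the shared parent class and a child such as KFold, and the trivial split logic is an assumption made purely for illustration.

```python
import inspect


class BaseSplitter:
    """Hypothetical parent class; stands in for a shared splitter base."""

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset.
        """
        # Trivial split: the first sample is the test set, the rest is training.
        yield list(range(1, len(X))), [0]


class PlainSplitter(BaseSplitter):
    """Hypothetical child class that does not use groups."""

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        groups : object
            Always ignored, exists for compatibility.
        """
        # Same signature as the parent, so existing callers keep working.
        return super().split(X, y=y, groups=groups)


# The public signature is unchanged; only the documentation differs.
assert inspect.signature(PlainSplitter.split) == inspect.signature(BaseSplitter.split)
```

Callers that already pass `groups` positionally, such as `PlainSplitter().split(X, y, groups)`, continue to work, while the parent's docstring is left untouched.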
Makes sense. As far as I can tell, the following cross-validators should have their docstring for the
I suggest updating it to the following, as this is how it has been updated in other classes.
class KFold(...):
    def split(self, X, y=None, groups=None):
        """...
        groups : object
            Always ignored, exists for compatibility.
        ...
        """
Also, do we want to update the docstring for both the
Hi all, I recently came across this issue and wanted to bump it.
I would recommend not only changing the docstring, but also raising an error, or at least a warning, whenever a user provides anything other than
I recommend the additional warning because even when the documentation is clear (like in StratifiedKFold), the split functions will still silently ignore the user's input, which might lead to unintended consequences. This seems like a good place to have
I can offer to attempt to push some code to resolve this issue, if you'd like. Would that be helpful?
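A rough sketch of the warning idea, under stated assumptions: `ToyKFold` is a hypothetical pure-Python stand-in, not scikit-learn's KFold, and the warning category and message are illustrative choices rather than anything decided in this thread. The signature keeps `groups` for backward compatibility but no longer ignores it silently.

```python
import warnings


class ToyKFold:
    """Hypothetical stand-in for a k-fold splitter; not scikit-learn's KFold."""

    def __init__(self, n_splits=2):
        self.n_splits = n_splits

    def split(self, X, y=None, groups=None):
        """Yield (train, test) index lists.

        groups : object
            Always ignored, exists for compatibility.
        """
        if groups is not None:
            # Keep accepting the argument, but tell the caller it has no effect.
            warnings.warn(
                f"The groups parameter is ignored by {type(self).__name__}",
                UserWarning,
                stacklevel=2,
            )
        n_samples = len(X)
        indices = list(range(n_samples))
        # The first n_samples % n_splits folds get one extra sample.
        fold_sizes = [n_samples // self.n_splits] * self.n_splits
        for i in range(n_samples % self.n_splits):
            fold_sizes[i] += 1
        start = 0
        for size in fold_sizes:
            test = indices[start:start + size]
            train = indices[:start] + indices[start + size:]
            yield train, test
            start += size
```

Because `split` is a generator, the warning is emitted when iteration starts rather than when `split` is called. Raising a `ValueError` instead would be stricter but would break callers that currently pass `groups` to every splitter uniformly, as the test quoted earlier in this thread does.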
@thomasjpfan : (taken from https://github.com/scikit-learn/scikit-learn/pull/22765/files#r825487547)
I haven't looked at how to proceed here. It certainly shouldn't claim it supports groups in the docs, and we need to check how to fix the API if we can.