Description
Describe the bug
For context, there is a small difference in the split method across the variants of the KFold class:
In class sklearn.model_selection.KFold, the split method has an optional parameter y (the same holds for class sklearn.model_selection.RepeatedKFold).
In class sklearn.model_selection.StratifiedKFold, the same parameter y is mandatory because "Stratification is done based on the y labels". As expected, omitting y when calling split causes an explicit error:
TypeError: StratifiedKFold.split() missing 1 required positional argument: 'y'
However, sklearn.model_selection.RepeatedStratifiedKFold is also a stratified variant that requires y, yet its split method erroneously leaves the parameter optional. This seems to be because it is implemented through the generic class _UnsupportedGroupCVMixin. As a result, omitting y produces an unclear error message, inconsistent with the one StratifiedKFold raises in the same situation.
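The mechanism described above can be illustrated with a small, stdlib-only sketch. The two classes are hypothetical stand-ins, not scikit-learn code: RepeatedSplitterSketch plays the role of a repeated wrapper whose split accepts y=None and forwards it unchanged to an inner stratified splitter that needs y. Because the wrapper's signature accepts a call without y, nothing fails at call time; the error only surfaces later, inside the inner splitter's own logic, with a message that does not name the missing argument. The exact error scikit-learn raises may differ from the one shown here.

```python
class StratifiedSplitterSketch:
    """Hypothetical stand-in for a stratified splitter: y is required."""

    def __init__(self, n_splits=2):
        self.n_splits = n_splits

    def split(self, X, y, groups=None):
        n = len(X)
        fold_of = [0] * n
        # Deal each class's indices round-robin across folds so every fold gets
        # roughly the same share of each label. With y=None, set(y) raises
        # "TypeError: 'NoneType' object is not iterable", which never mentions y.
        for label in sorted(set(y)):
            members = [i for i in range(n) if y[i] == label]
            for k, i in enumerate(members):
                fold_of[i] = k % self.n_splits
        for fold in range(self.n_splits):
            train = [i for i in range(n) if fold_of[i] != fold]
            test = [i for i in range(n) if fold_of[i] == fold]
            yield train, test


class RepeatedSplitterSketch:
    """Hypothetical stand-in for a repeated wrapper: y is optional and forwarded as-is."""

    def __init__(self, n_splits=2, n_repeats=3):
        self.n_splits = n_splits
        self.n_repeats = n_repeats

    def split(self, X, y=None, groups=None):
        # The signature accepts a call without y, so no "missing argument" error
        # is raised here; None is passed straight through to the inner splitter.
        # (Per-repeat shuffling is omitted for brevity.)
        for _ in range(self.n_repeats):
            inner = StratifiedSplitterSketch(n_splits=self.n_splits)
            yield from inner.split(X, y, groups)


if __name__ == "__main__":
    x = ['a'] * 100
    try:
        list(RepeatedSplitterSketch(n_splits=2, n_repeats=3).split(x))
    except TypeError as exc:
        print(exc)  # 'NoneType' object is not iterable
```

Note that the failure only appears once the generator is iterated, which is why the sketch wraps the call in list().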
Steps/Code to Reproduce
from sklearn.model_selection import StratifiedKFold, RepeatedStratifiedKFold
x = ['a'] * 100
y = [0] * 90 + [1] * 10
# y is NOT optional -> error "missing 1 required positional argument: 'y'", as expected
try:
    for i, (train, test) in enumerate(StratifiedKFold(n_splits=2).split(x)):
        print('i =', i, '. train =', train)
        print('i =', i, '. test =', test)
except TypeError as exc:
    print('StratifiedKFold:', exc)
# y is supposed to be optional according to the documentation, but omitting it causes an unclear error message
try:
    for i, (train, test) in enumerate(RepeatedStratifiedKFold(n_splits=2, n_repeats=3).split(x)):
        print('i =', i, '. train =', train)
        print('i =', i, '. test =', test)
except Exception as exc:
    print('RepeatedStratifiedKFold:', type(exc).__name__, exc)
Expected Results
The same error, "missing 1 required positional argument: 'y'", in both cases.
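A minimal sketch of the expected behavior, assuming the repeated stratified splitter simply declared y without a default. StrictRepeatedSplitterSketch is a hypothetical stand-in, not scikit-learn code, and its round-robin stratification and lack of shuffling are simplifications. With y required in the signature, Python itself rejects split(X) at call time with the standard "missing 1 required positional argument: 'y'" TypeError, the same kind of message StratifiedKFold.split produces.

```python
class StrictRepeatedSplitterSketch:
    """Hypothetical repeated stratified wrapper whose split() requires y."""

    def __init__(self, n_splits=2, n_repeats=3):
        self.n_splits = n_splits
        self.n_repeats = n_repeats

    def split(self, X, y, groups=None):
        # y has no default, so calling split(X) fails immediately, before any
        # splitting, with Python's standard "missing 1 required positional
        # argument: 'y'" TypeError.
        n = len(X)
        for _ in range(self.n_repeats):  # per-repeat shuffling omitted for brevity
            # Simplified stratification: deal each class's indices round-robin.
            fold_of = [0] * n
            for label in sorted(set(y)):
                members = [i for i in range(n) if y[i] == label]
                for k, i in enumerate(members):
                    fold_of[i] = k % self.n_splits
            for fold in range(self.n_splits):
                train = [i for i in range(n) if fold_of[i] != fold]
                test = [i for i in range(n) if fold_of[i] == fold]
                yield train, test


if __name__ == "__main__":
    try:
        StrictRepeatedSplitterSketch(n_splits=2, n_repeats=3).split(['a'] * 100)
    except TypeError as exc:
        # On Python 3.10+: StrictRepeatedSplitterSketch.split() missing 1 required positional argument: 'y'
        print(exc)
```

Keeping y=None and adding an explicit "if y is None" check would also give a clear error, but with a different message; only a required parameter yields exactly the message StratifiedKFold produces.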
Actual Results
The first example raises the expected TypeError; the second raises an unclear error that is inconsistent with the first.
Versions
System:
    python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
    executable: /usr/bin/python3
    machine: Linux-6.1.85+-x86_64-with-glibc2.35

Python dependencies:
    sklearn: 1.2.2
    pip: 23.1.2
    setuptools: 67.7.2
    numpy: 1.25.2
    scipy: 1.11.4
    Cython: 3.0.10
    pandas: 2.0.3
    matplotlib: 3.7.1
    joblib: 1.4.2
    threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    num_threads: 2
    prefix: libopenblas
    filepath: /usr/local/lib/python3.10/dist-packages/numpy.libs/libopenblas64_p-r0-5007b62f.3.23.dev.so
    version: 0.3.23.dev
    threading_layer: pthreads
    architecture: Haswell

    user_api: openmp
    internal_api: openmp
    num_threads: 2
    prefix: libgomp
    filepath: /usr/local/lib/python3.10/dist-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
    version: None

    user_api: blas
    internal_api: openblas
    num_threads: 2
    prefix: libopenblas
    filepath: /usr/local/lib/python3.10/dist-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so
    version: 0.3.21.dev
    threading_layer: pthreads
    architecture: Haswell