Erroneous optional status for y parameter in RepeatedStratifiedKFold.split · Issue #29369 · scikit-learn/scikit-learn · GitHub

Closed
@erwanm

Description


Describe the bug

For context, the signature of the split method differs slightly between the variants of the KFold class:

In sklearn.model_selection.KFold, the split method takes an optional parameter y (as does sklearn.model_selection.RepeatedKFold).

In sklearn.model_selection.StratifiedKFold, the same parameter y is mandatory, because "Stratification is done based on the y labels". As expected, omitting y when calling split raises an explicit error:

TypeError: StratifiedKFold.split() missing 1 required positional argument: 'y'

However, sklearn.model_selection.RepeatedStratifiedKFold is also a stratified variant that requires y, yet its split method erroneously declares y as optional. This seems to be because it is implemented through a general class, _UnsupportedGroupCVMixin. As a result, omitting y produces an unclear error message that is inconsistent with the one StratifiedKFold raises in the same situation.
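A minimal plain-Python sketch of the mechanism described above, assuming only what the report states: an inner stratified splitter whose split requires y, wrapped by a generic repeating class that declares y=None and forwards it unchanged. The names stratified_split and GenericRepeatedSplitter are hypothetical; this is not scikit-learn's code, and the error it produces is not scikit-learn's exact message.

```python
import inspect


def stratified_split(X, y):
    """Hypothetical inner splitter: needs y to stratify (one fold per label)."""
    for label in sorted(set(y)):  # set(None) raises a TypeError about NoneType
        test = [i for i, yi in enumerate(y) if yi == label]
        train = [i for i in range(len(X)) if i not in test]
        yield train, test


class GenericRepeatedSplitter:
    """Hypothetical generic wrapper: y defaults to None and is forwarded as-is."""

    def __init__(self, n_repeats=3):
        self.n_repeats = n_repeats

    def split(self, X, y=None, groups=None):
        # y=None is accepted silently here and only fails inside the inner splitter
        for _ in range(self.n_repeats):
            yield from stratified_split(X, y)


if __name__ == "__main__":
    print(inspect.signature(stratified_split))
    print(inspect.signature(GenericRepeatedSplitter.split))
    try:
        next(GenericRepeatedSplitter().split(['a'] * 4))
    except TypeError as exc:
        print('unclear error:', exc)
```

Because the wrapper's signature accepts a missing y, Python's argument binding cannot catch the mistake; the None only fails later, inside the inner splitter, with an error that never names y.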

Steps/Code to Reproduce

from sklearn.model_selection import StratifiedKFold, RepeatedStratifiedKFold

X = ['a'] * 100
y = [0] * 90 + [1] * 10  # labels exist, but are deliberately NOT passed to split()

# StratifiedKFold: y is NOT optional, so omitting it raises the expected
# TypeError: StratifiedKFold.split() missing 1 required positional argument: 'y'
try:
    for i, (train, test) in enumerate(StratifiedKFold(n_splits=2).split(X)):
        print('i =', i, '. train =', train)
        print('i =', i, '. test =', test)
except TypeError as exc:
    print('StratifiedKFold:', exc)

# RepeatedStratifiedKFold: y is documented as optional, but omitting it
# fails with an unclear error message instead
try:
    for i, (train, test) in enumerate(RepeatedStratifiedKFold(n_splits=2, n_repeats=3).split(X)):
        print('i =', i, '. train =', train)
        print('i =', i, '. test =', test)
except Exception as exc:
    print('RepeatedStratifiedKFold:', type(exc).__name__, exc)

Expected Results

Same error message 'missing 1 required positional argument' in both cases.
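One way to obtain the expected behavior is for the repeated variant to declare y without a default, so that Python's own argument binding produces the identical "missing 1 required positional argument" message. The sketch below is hypothetical plain Python (stratified_split and RepeatedStratifiedSplitter are illustrative names, not scikit-learn's classes), assuming the inner stratified splitter already requires y.

```python
def stratified_split(X, y):
    """Hypothetical inner splitter: one fold per class label; y is required."""
    for label in sorted(set(y)):
        test = [i for i, yi in enumerate(y) if yi == label]
        train = [i for i in range(len(X)) if i not in test]
        yield train, test


class RepeatedStratifiedSplitter:
    """Hypothetical repeated wrapper whose signature mirrors the inner one."""

    def __init__(self, n_repeats=3):
        self.n_repeats = n_repeats

    def split(self, X, y, groups=None):
        # No default for y: calling split(X) fails immediately, at call time,
        # with "... missing 1 required positional argument: 'y'".
        return self._repeat(X, y)

    def _repeat(self, X, y):
        # Repeat the inner stratified split n_repeats times
        for _ in range(self.n_repeats):
            yield from stratified_split(X, y)
```

The alternative, keeping y=None and raising an explicit TypeError when y is None, leaves the existing signature unchanged but means writing the message by hand so that it matches.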

Actual Results

The first example raises the expected TypeError. The second example fails with a different, less explicit error message, inconsistent with the first.

Versions

System:
    python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
executable: /usr/bin/python3
   machine: Linux-6.1.85+-x86_64-with-glibc2.35

Python dependencies:
      sklearn: 1.2.2
          pip: 23.1.2
   setuptools: 67.7.2
        numpy: 1.25.2
        scipy: 1.11.4
       Cython: 3.0.10
       pandas: 2.0.3
   matplotlib: 3.7.1
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 2
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/numpy.libs/libopenblas64_p-r0-5007b62f.3.23.dev.so
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 2
         prefix: libgomp
       filepath: /usr/local/lib/python3.10/dist-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None

       user_api: blas
   internal_api: openblas
    num_threads: 2
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/scipy.libs/libopenblasp-r0-23e5df77.3.21.dev.so
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: Haswell

Labels: Bug, Easy