`y` and `groups` parameters to `StratifiedGroupKFold.split()` are optional · Issue #30742 · scikit-learn/scikit-learn · GitHub
y and groups parameters to StratifiedGroupKFold.split() are optional #30742

@Teagum

Describe the bug

StratifiedGroupKFold.split has the signature (self, X, y=None, groups=None), indicating that both y and groups may be omitted when calling split.

However, omitting only groups results in TypeError: iteration over a 0-d array. Omitting both y and groups, or only y, results in ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead. Contrary to the signature, this indicates that y and groups are required, not optional.

I would instead expect behavior consistent with e.g. StratifiedKFold, where the y parameter to split is not optional.

StratifiedKFold and StratifiedGroupKFold both inherit from _BaseKFold, which provides .split. However, StratifiedKFold implements its own split method instead of relying on the one from _BaseKFold, as StratifiedGroupKFold does.
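
The pattern described above can be sketched without scikit-learn. The classes below (BaseSplitter, StrictSplitter, LooseSplitter) and the helper required_params are hypothetical stand-ins, not scikit-learn's actual code. They only illustrate how a subclass that overrides split can make y required, while a subclass that inherits split keeps the None defaults.

```python
import inspect


class BaseSplitter:
    """Hypothetical base class whose split() declares y and groups as optional."""

    def split(self, X, y=None, groups=None):
        yield from self._iter_test_indices(X, y, groups)

    def _iter_test_indices(self, X, y, groups):
        # Placeholder: a real splitter would compute folds here.
        yield list(range(len(X)))


class StrictSplitter(BaseSplitter):
    """Overrides split() so that y has no default (StratifiedKFold-style)."""

    def split(self, X, y, groups=None):
        return super().split(X, y, groups)


class LooseSplitter(BaseSplitter):
    """Inherits split() unchanged, so y and groups still default to None."""


def required_params(cls):
    """Return the names of split() parameters that have no default value."""
    params = inspect.signature(cls.split).parameters
    return [
        name
        for name, param in params.items()
        if name != "self" and param.default is inspect.Parameter.empty
    ]


strict_required = required_params(StrictSplitter)
loose_required = required_params(LooseSplitter)
print(strict_required)
print(loose_required)
```

In this sketch only the overriding subclass advertises y as required. The inheriting subclass still presents y and groups as optional, even if its internals cannot handle None.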

Steps/Code to Reproduce

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

rng = np.random.default_rng()

X = rng.normal(size=(10, 3))
y = np.concatenate((np.ones(5, dtype=int), np.zeros(5, dtype=int)))
g = np.tile([1, 0], 5)

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=y, groups=None))           # TypeError

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=None))    # ValueError

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=g))          # ValueError

Expected Results

Either no error is raised when y, groups, or both are not specified, or the default of None is removed for both parameters in the function signature.
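
As a rough illustration of the second option, the sketch below uses two hypothetical free functions (split_with_defaults and split_without_defaults, neither part of scikit-learn) to show what removing the defaults would change for the caller: Python itself reports the missing arguments at call time, before any internal code runs.

```python
def split_with_defaults(X, y=None, groups=None):
    """Hypothetical: accepts the call, so any failure happens later and deeper."""
    return X, y, groups


def split_without_defaults(X, y, groups):
    """Hypothetical: the signature itself states that y and groups are required."""
    return X, y, groups


# With defaults, the call succeeds and None silently flows onward.
accepted = split_with_defaults([1, 2, 3])

# Without defaults, the interpreter raises immediately and names the arguments.
try:
    split_without_defaults([1, 2, 3])
except TypeError as exc:
    message = str(exc)

print(message)
```

The resulting TypeError names both y and groups, which is arguably easier to act on than an error raised several frames down inside _iter_test_indices.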

Actual Results

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 2
      1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=y, groups=None))    # TypeError

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
    403 if self.n_splits > n_samples:
    404     raise ValueError(
    405         (
    406             "Cannot have number of splits n_splits={0} greater"
    407             " than the number of samples: n_samples={1}."
    408         ).format(self.n_splits, n_samples)
    409     )
--> 411 for train, test in super().split(X, y, groups):
    412     yield train, test

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
    140 X, y, groups = indexable(X, y, groups)
    141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
    143     train_index = indices[np.logical_not(test_index)]
    144     test_index = indices[test_index]

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
    149 def _iter_test_masks(self, X=None, y=None, groups=None):
    150     """Generates boolean masks corresponding to test sets.
    151 
    152     By default, delegates to _iter_test_indices(X, y, groups)
    153     """
--> 154     for test_index in self._iter_test_indices(X, y, groups):
    155         test_mask = np.zeros(_num_samples(X), dtype=bool)
    156         test_mask[test_index] = True

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1035, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
   1031 _, groups_inv, groups_cnt = np.unique(
   1032     groups, return_inverse=True, return_counts=True
   1033 )
   1034 y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
-> 1035 for class_idx, group_idx in zip(y_inv, groups_inv):
   1036     y_counts_per_group[group_idx, class_idx] += 1
   1038 y_counts_per_fold = np.zeros((self.n_splits, n_classes))

TypeError: iteration over a 0-d array

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 2
      1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=None, groups=g))

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
    403 if self.n_splits > n_samples:
    404     raise ValueError(
    405         (
    406             "Cannot have number of splits n_splits={0} greater"
    407             " than the number of samples: n_samples={1}."
    408         ).format(self.n_splits, n_samples)
    409     )
--> 411 for train, test in super().split(X, y, groups):
    412     yield train, test

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
    140 X, y, groups = indexable(X, y, groups)
    141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
    143     train_index = indices[np.logical_not(test_index)]
    144     test_index = indices[test_index]

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
    149 def _iter_test_masks(self, X=None, y=None, groups=None):
    150     """Generates boolean masks corresponding to test sets.
    151 
    152     By default, delegates to _iter_test_indices(X, y, groups)
    153     """
--> 154     for test_index in self._iter_test_indices(X, y, groups):
    155         test_mask = np.zeros(_num_samples(X), dtype=bool)
    156         test_mask[test_index] = True

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1008, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
   1006 allowed_target_types = ("binary", "multiclass")
   1007 if type_of_target_y not in allowed_target_types:
-> 1008     raise ValueError(
   1009         "Supported target types are: {}. Got {!r} instead.".format(
   1010             allowed_target_types, type_of_target_y
   1011         )
   1012     )
   1014 y = column_or_1d(y)
   1015 _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)

ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead.

Versions

System:
    python: 3.12.4 (main, Jul 23 2024, 09:14:16) [GCC 14.1.1 20240522]
executable: /<PATH>/bin/python
   machine: Linux-6.12.9-arch1-1-x86_64-with-glibc2.40

Python dependencies:
      sklearn: 1.6.1
          pip: 24.3.1
   setuptools: None
        numpy: 2.2.2
        scipy: 1.15.1
       Cython: None
       pandas: 2.2.3
   matplotlib: 3.10.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: /<PATH>/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-6bb31eeb.so
        version: 0.3.28
threading_layer: pthreads
   architecture: SkylakeX

       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: /<PATH>/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
        version: 0.3.28
threading_layer: pthreads
   architecture: SkylakeX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libgomp
       filepath: /<PATH>/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None

Labels: Validation (related to input validation)
