`y`, and `groups` parameters to `StratifiedGroupKFold.split()` are optional · Issue #30742 · scikit-learn/scikit-learn · GitHub
Open
@Teagum

Describe the bug

`StratifiedGroupKFold.split` has the signature `(self, X, y=None, groups=None)`, indicating that both `y` and `groups` may be omitted when calling `split`.

However, omitting only `groups` results in `TypeError: iteration over a 0-d array`. Omitting both `y` and `groups`, or only `y`, results in `ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead.` Contrary to the signature, this indicates that `y` and `groups` are required, not optional.

I would instead expect behavior consistent with e.g. `StratifiedKFold`, where the `y` parameter to `split` is not optional.

`StratifiedKFold` and `StratifiedGroupKFold` both inherit from `_BaseKFold`, which provides `.split`. However, `StratifiedKFold` implements its own `split` method, whereas `StratifiedGroupKFold` uses the one inherited from `_BaseKFold`.
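The inheritance pattern described here can be illustrated with stand-in classes. This is a minimal sketch, not scikit-learn's actual code: `Base`, `OverridesSplit`, and `InheritsSplit` are hypothetical names that only mirror the roles of `_BaseKFold`, `StratifiedKFold`, and `StratifiedGroupKFold`, and `required_params` is an illustrative helper.

```python
import inspect


class Base:
    """Hypothetical stand-in for _BaseKFold: declares y and groups as optional."""

    def split(self, X, y=None, groups=None):
        yield from self._iter_indices(X, y, groups)

    def _iter_indices(self, X, y, groups):
        yield list(range(len(X)))


class OverridesSplit(Base):
    """Stand-in for StratifiedKFold: redefines split so that y is required."""

    def split(self, X, y, groups=None):
        return super().split(X, y, groups)


class InheritsSplit(Base):
    """Stand-in for StratifiedGroupKFold: needs y and groups, but inherits split."""

    def _iter_indices(self, X, y, groups):
        if y is None or groups is None:
            raise ValueError("y and groups are required")
        yield list(range(len(X)))


def required_params(method):
    """Return the names of parameters (other than self) that have no default."""
    params = inspect.signature(method).parameters.values()
    return [
        p.name
        for p in params
        if p.name != "self" and p.default is inspect.Parameter.empty
    ]


overriding = required_params(OverridesSplit.split)
inheriting = required_params(InheritsSplit.split)
print(overriding)  # ['X', 'y']
print(inheriting)  # ['X']
```

In this sketch the inherited signature advertises `y` and `groups` as optional even though the subclass cannot work without them, which is the mismatch reported in this issue.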

Steps/Code to Reproduce

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

rng = np.random.default_rng()

X = rng.normal(size=(10, 3))
y = np.concatenate((np.ones(5, dtype=int), np.zeros(5, dtype=int)))
g = np.tile([1, 0], 5)

# Each call below raises, so run them one at a time (e.g. in separate cells).
sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=y, groups=None))     # TypeError

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=None))  # ValueError

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=g))     # ValueError

Expected Results

Either no error is raised when `y`, `groups`, or both are omitted, or the default of `None` is removed for both parameters in the function signature.
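Whichever option is chosen, the current failure surfaces as an unrelated-looking error deep in the call stack. As a rough illustration of what failing early could look like, here is a minimal sketch; `check_required_split_args` and `describe` are hypothetical helpers, not part of scikit-learn, and the error wording is an assumption.

```python
def check_required_split_args(y, groups, owner="StratifiedGroupKFold"):
    """Raise a descriptive ValueError if y or groups is missing.

    Hypothetical helper for illustration; not a scikit-learn function.
    """
    missing = [name for name, value in (("y", y), ("groups", groups)) if value is None]
    if missing:
        raise ValueError(
            f"{owner}.split requires the following argument(s): {', '.join(missing)}"
        )


def describe(y, groups):
    """Return the error message the check produces, or None if it passes."""
    try:
        check_required_split_args(y, groups)
    except ValueError as exc:
        return str(exc)
    return None


print(describe(y=[1, 0], groups=None))
print(describe(y=None, groups=None))
print(describe(y=[1, 0], groups=[0, 1]))
```

A check of this kind would name the missing argument directly instead of reporting a 0-d array or an unknown target type.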

Actual Results

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 2
      1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=y, groups=None))    # TypeError

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
    403 if self.n_splits > n_samples:
    404     raise ValueError(
    405         (
    406             "Cannot have number of splits n_splits={0} greater"
    407             " than the number of samples: n_samples={1}."
    408         ).format(self.n_splits, n_samples)
    409     )
--> 411 for train, test in super().split(X, y, groups):
    412     yield train, test

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
    140 X, y, groups = indexable(X, y, groups)
    141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
    143     train_index = indices[np.logical_not(test_index)]
    144     test_index = indices[test_index]

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
    149 def _iter_test_masks(self, X=None, y=None, groups=None):
    150     """Generates boolean masks corresponding to test sets.
    151 
    152     By default, delegates to _iter_test_indices(X, y, groups)
    153     """
--> 154     for test_index in self._iter_test_indices(X, y, groups):
    155         test_mask = np.zeros(_num_samples(X), dtype=bool)
    156         test_mask[test_index] = True

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1035, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
   1031 _, groups_inv, groups_cnt = np.unique(
   1032     groups, return_inverse=True, return_counts=True
   1033 )
   1034 y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
-> 1035 for class_idx, group_idx in zip(y_inv, groups_inv):
   1036     y_counts_per_group[group_idx, class_idx] += 1
   1038 y_counts_per_fold = np.zeros((self.n_splits, n_classes))

TypeError: iteration over a 0-d array
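
The message `iteration over a 0-d array` is what NumPy raises when a zero-dimensional array is iterated. The traceback suggests that `groups_inv` is such an array when `groups` is `None`. The sketch below reproduces only that final step with a hand-built 0-d array, so `groups_inv` here is an assumed stand-in rather than the value scikit-learn actually computes.

```python
import numpy as np

# Stand-ins for the two arrays zipped at line 1035 of the traceback.
y_inv = np.array([1, 1, 0, 0])  # an ordinary 1-d inverse index
groups_inv = np.array(0)        # assumed: a 0-d array, which cannot be iterated

try:
    # zip() asks each argument for an iterator, which fails for a 0-d array.
    for class_idx, group_idx in zip(y_inv, groups_inv):
        pass
    message = None
except TypeError as exc:
    message = str(exc)

print(groups_inv.ndim)
print(message)
```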

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 2
      1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=None, groups=g))

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
    403 if self.n_splits > n_samples:
    404     raise ValueError(
    405         (
    406             "Cannot have number of splits n_splits={0} greater"
    407             " than the number of samples: n_samples={1}."
    408         ).format(self.n_splits, n_samples)
    409     )
--> 411 for train, test in super().split(X, y, groups):
    412     yield train, test

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
    140 X, y, groups = indexable(X, y, groups)
    141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
    143     train_index = indices[np.logical_not(test_index)]
    144     test_index = indices[test_index]

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
    149 def _iter_test_masks(self, X=None, y=None, groups=None):
    150     """Generates boolean masks corresponding to test sets.
    151 
    152     By default, delegates to _iter_test_indices(X, y, groups)
    153     """
--> 154     for test_index in self._iter_test_indices(X, y, groups):
    155         test_mask = np.zeros(_num_samples(X), dtype=bool)
    156         test_mask[test_index] = True

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1008, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
   1006 allowed_target_types = ("binary", "multiclass")
   1007 if type_of_target_y not in allowed_target_types:
-> 1008     raise ValueError(
   1009         "Supported target types are: {}. Got {!r} instead.".format(
   1010             allowed_target_types, type_of_target_y
   1011         )
   1012     )
   1014 y = column_or_1d(y)
   1015 _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)

ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead.

Versions

System:
    python: 3.12.4 (main, Jul 23 2024, 09:14:16) [GCC 14.1.1 20240522]
executable: /<PATH>/bin/python
   machine: Linux-6.12.9-arch1-1-x86_64-with-glibc2.40

Python dependencies:
      sklearn: 1.6.1
          pip: 24.3.1
   setuptools: None
        numpy: 2.2.2
        scipy: 1.15.1
       Cython: None
       pandas: 2.2.3
   matplotlib: 3.10.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: /<PATH>/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-6bb31eeb.so
        version: 0.3.28
threading_layer: pthreads
   architecture: SkylakeX

       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: /<PATH>/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
        version: 0.3.28
threading_layer: pthreads
   architecture: SkylakeX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libgomp
       filepath: /<PATH>/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
