Description
Describe the bug
StratifiedGroupKFold.split has the signature (self, X, y=None, groups=None), indicating that neither y nor groups needs to be specified when calling split.
However, omitting only groups results in TypeError: iteration over a 0-d array.
Also, omitting both y and groups, or only y, results in ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead.
This indicates, contrary to the signature, that y and groups are required rather than optional.
I would instead expect behavior consistent with, e.g., StratifiedKFold, where the y parameter to split is not optional.
StratifiedKFold and StratifiedGroupKFold both inherit from _BaseKFold, which provides .split. However, StratifiedKFold implements its own split method, whereas StratifiedGroupKFold uses the inherited _BaseKFold.split.
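The inheritance difference can be illustrated without scikit-learn. The sketch below uses hypothetical stand-in classes (BaseKFoldSketch, OverridingSketch, InheritingSketch are not library names) to show that a subclass which overrides split controls its own defaults, while a subclass that inherits split also inherits y=None and groups=None. It only mirrors the pattern; it does not reproduce the real splitting logic.

```python
import inspect


class BaseKFoldSketch:
    """Hypothetical stand-in for _BaseKFold: y and groups default to None."""

    def split(self, X, y=None, groups=None):
        yield from ()  # splitting logic omitted in this sketch


class OverridingSketch(BaseKFoldSketch):
    """Mirrors the StratifiedKFold pattern: overrides split, y has no default."""

    def split(self, X, y, groups=None):
        return super().split(X, y, groups)


class InheritingSketch(BaseKFoldSketch):
    """Mirrors the StratifiedGroupKFold pattern: no split of its own."""


for cls in (OverridingSketch, InheritingSketch):
    params = inspect.signature(cls.split).parameters
    # Parameter.empty means "required"; None means the parameter is optional.
    print(cls.__name__, "y default:", params["y"].default)
```

The same inspect.signature check can be run against the real classes to compare their split signatures.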
Steps/Code to Reproduce
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
rng = np.random.default_rng()
X = rng.normal(size=(10, 3))
y = np.concatenate((np.ones(5, dtype=int), np.zeros(5, dtype=int)))
g = np.tile([1, 0], 5)
sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=y, groups=None)) # TypeError
sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=None)) # ValueError
sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=g)) # ValueError
Expected Results
Either no error is raised when y, groups, or both are omitted, or the default of None is removed from both parameters in the function signature.
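One side effect of the second option: because split is a generator, a check inside its body only fires on the first next(), whereas a missing required argument raises TypeError as soon as split is called. A minimal sketch with hypothetical classes (not scikit-learn code), assuming the override simply delegates to the parent:

```python
class BaseSplitterSketch:
    """Hypothetical splitter whose split() keeps the None defaults."""

    def split(self, X, y=None, groups=None):
        # Generator body: nothing here runs until next() is called.
        if y is None or groups is None:
            raise ValueError("y and groups are required")
        for i in range(len(X)):
            yield [j for j in range(len(X)) if j != i], [i]


class RequiredArgsSplitterSketch(BaseSplitterSketch):
    """Hypothetical subclass with the defaults removed (the second option)."""

    def split(self, X, y, groups):
        # Only the signature changes; the parent's logic is reused.
        return super().split(X, y, groups)


X = [[0], [1], [2]]

# With defaults, the call succeeds and the error only appears on iteration.
gen = BaseSplitterSketch().split(X)
try:
    next(gen)
except ValueError as exc:
    print("on next():", exc)

# Without defaults, the call itself fails before any iteration.
try:
    RequiredArgsSplitterSketch().split(X)
except TypeError as exc:
    print("on call:", exc)
```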
Actual Results
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 2
1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=y, groups=None)) # TypeError
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
403 if self.n_splits > n_samples:
404 raise ValueError(
405 (
406 "Cannot have number of splits n_splits={0} greater"
407 " than the number of samples: n_samples={1}."
408 ).format(self.n_splits, n_samples)
409 )
--> 411 for train, test in super().split(X, y, groups):
412 yield train, test
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
140 X, y, groups = indexable(X, y, groups)
141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
143 train_index = indices[np.logical_not(test_index)]
144 test_index = indices[test_index]
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
149 def _iter_test_masks(self, X=None, y=None, groups=None):
150 """Generates boolean masks corresponding to test sets.
151
152 By default, delegates to _iter_test_indices(X, y, groups)
153 """
--> 154 for test_index in self._iter_test_indices(X, y, groups):
155 test_mask = np.zeros(_num_samples(X), dtype=bool)
156 test_mask[test_index] = True
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1035, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
1031 _, groups_inv, groups_cnt = np.unique(
1032 groups, return_inverse=True, return_counts=True
1033 )
1034 y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
-> 1035 for class_idx, group_idx in zip(y_inv, groups_inv):
1036 y_counts_per_group[group_idx, class_idx] += 1
1038 y_counts_per_fold = np.zeros((self.n_splits, n_classes))
TypeError: iteration over a 0-d array
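A possible explanation for this TypeError, based on the traceback: with groups=None, np.unique at line 1031 receives None, which NumPy treats as a 0-d object array. As far as I understand, in NumPy 2.x (2.2.2 here) the inverse returned with return_inverse=True is reshaped to the input's shape, so groups_inv is itself 0-d and zip() fails at line 1035. The NumPy-only sketch below mimics that step; the y_inv values are made up for illustration.

```python
import numpy as np

# Mimic the np.unique call at _split.py:1031 when groups=None.
groups = None
_, groups_inv, groups_cnt = np.unique(groups, return_inverse=True, return_counts=True)
print("groups_inv shape:", groups_inv.shape, "counts:", groups_cnt)

# Made-up class indices standing in for y_inv (two classes, four samples).
y_inv = np.array([0, 1, 0, 1])

try:
    # zip() calls iter() on each argument right away, so a 0-d array fails here.
    pairs = list(zip(y_inv, groups_inv))
    print("zip succeeded:", pairs)
except TypeError as exc:
    print("TypeError:", exc)
```

Under NumPy 1.x the inverse would instead be 1-d with a single element, so this step would not raise there.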
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[13], line 2
1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=None, groups=g))
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
403 if self.n_splits > n_samples:
404 raise ValueError(
405 (
406 "Cannot have number of splits n_splits={0} greater"
407 " than the number of samples: n_samples={1}."
408 ).format(self.n_splits, n_samples)
409 )
--> 411 for train, test in super().split(X, y, groups):
412 yield train, test
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
140 X, y, groups = indexable(X, y, groups)
141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
143 train_index = indices[np.logical_not(test_index)]
144 test_index = indices[test_index]
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
149 def _iter_test_masks(self, X=None, y=None, groups=None):
150 """Generates boolean masks corresponding to test sets.
151
152 By default, delegates to _iter_test_indices(X, y, groups)
153 """
--> 154 for test_index in self._iter_test_indices(X, y, groups):
155 test_mask = np.zeros(_num_samples(X), dtype=bool)
156 test_mask[test_index] = True
File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1008, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
1006 allowed_target_types = ("binary", "multiclass")
1007 if type_of_target_y not in allowed_target_types:
-> 1008 raise ValueError(
1009 "Supported target types are: {}. Got {!r} instead.".format(
1010 allowed_target_types, type_of_target_y
1011 )
1012 )
1014 y = column_or_1d(y)
1015 _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)
ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead.
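The ValueError looks related. Assuming StratifiedGroupKFold converts y with np.asarray before calling type_of_target (those lines are not shown in the traceback), y=None becomes a 0-d object array, which, judging by the error message, type_of_target classifies as 'unknown'. The NumPy side of that can be checked directly:

```python
import numpy as np

# What y=None becomes after np.asarray: a 0-d array with object dtype.
y = np.asarray(None)
print(y.ndim, y.shape, y.dtype)

# By contrast, the y from the reproducer is a 1-d integer array.
y_ok = np.asarray(np.concatenate((np.ones(5, dtype=int), np.zeros(5, dtype=int))))
print(y_ok.ndim, y_ok.shape, y_ok.dtype)
```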
Versions
System:
python: 3.12.4 (main, Jul 23 2024, 09:14:16) [GCC 14.1.1 20240522]
executable: /<PATH>/bin/python
machine: Linux-6.12.9-arch1-1-x86_64-with-glibc2.40
Python dependencies:
sklearn: 1.6.1
pip: 24.3.1
setuptools: None
numpy: 2.2.2
scipy: 1.15.1
Cython: None
pandas: 2.2.3
matplotlib: 3.10.0
joblib: 1.4.2
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libscipy_openblas
filepath: /<PATH>/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-6bb31eeb.so
version: 0.3.28
threading_layer: pthreads
architecture: SkylakeX
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libscipy_openblas
filepath: /<PATH>/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
version: 0.3.28
threading_layer: pthreads
architecture: SkylakeX
user_api: openmp
internal_api: openmp
num_threads: 8
prefix: libgomp
filepath: /<PATH>/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None