`y`, and `groups` parameters to `StratifiedGroupKFold.split()` are optional · Issue #30742 · scikit-learn/scikit-learn · GitHub
Teagum opened this issue Jan 31, 2025 · 8 comments
Labels
Documentation, Validation (related to input validation)

Comments

@Teagum
Teagum commented Jan 31, 2025

Describe the bug

StratifiedGroupKFold.split has the signature (self, X, y=None, groups=None), indicating that both y and groups may be omitted when calling split.

However, omitting only groups results in TypeError: iteration over a 0-d array. Also, when omitting both y and groups, or only y, the result is ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead. This indicates, contrary to the signature, that y and groups are required and not optional.

I would instead expect behavior consistent with e.g. StratifiedKFold, where the y parameter to split is not optional.

StratifiedKFold and StratifiedGroupKFold both inherit from _BaseKFold, which provides .split. However, StratifiedKFold implements its own split method instead of relying on the one from _BaseKFold, as StratifiedGroupKFold does.
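To illustrate the inheritance pattern described here, a small self-contained sketch may help. The classes below are toy stand-ins, not scikit-learn code: the names ToyBaseSplitter, ToyStratified, ToyStratifiedGroup and required_params are hypothetical. They only mimic how an inherited split with None defaults can advertise parameters as optional even though the subclass hook needs them.

```python
import inspect


class ToyBaseSplitter:
    """Hypothetical base class whose split() declares y and groups as optional."""

    def split(self, X, y=None, groups=None):
        # Delegate to a hook that subclasses implement.
        yield from self._iter_test_indices(X, y, groups)

    def _iter_test_indices(self, X, y, groups):
        raise NotImplementedError


class ToyStratified(ToyBaseSplitter):
    """Overrides split() so that the signature itself says y is required."""

    def split(self, X, y, groups=None):
        return super().split(X, y, groups)

    def _iter_test_indices(self, X, y, groups):
        yield list(range(len(X)))


class ToyStratifiedGroup(ToyBaseSplitter):
    """Inherits split() unchanged, although the hook needs y and groups."""

    def _iter_test_indices(self, X, y, groups):
        # zip() fails on None, loosely mirroring the late, unhelpful error.
        for _label, _group in zip(y, groups):
            pass
        yield list(range(len(X)))


def required_params(method):
    """Return the names of parameters that have no default value."""
    params = inspect.signature(method).parameters.values()
    return [p.name for p in params if p.default is inspect.Parameter.empty]


print(required_params(ToyStratified().split))       # ['X', 'y']
print(required_params(ToyStratifiedGroup().split))  # ['X']
```

In this toy model the second class reports only X as required, so a caller who trusts the signature only finds out about the real requirement when iteration starts.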

Steps/Code to Reproduce

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

rng = np.random.default_rng()

X = rng.normal(size=(10, 3))
y = np.concatenate((np.ones(5, dtype=int), np.zeros(5, dtype=int)))
g = np.tile([1, 0], 5)

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=y, groups=None))           # TypeError

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=None))    # ValueError

sgkf = StratifiedGroupKFold(n_splits=5)
next(sgkf.split(X=X, y=None, groups=g))          # ValueError

Expected Results

Either no error is raised if y, groups, or both are not specified, or the default of None is removed for both parameters in the function signature.
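As a rough illustration of the second option, the same effect can be approximated on the caller's side. The helper below is a hypothetical sketch (split_strict and EchoSplitter are made-up names, not part of scikit-learn); it assumes only that the splitter exposes a split(X, y, groups) method and uses a tiny stand-in splitter so that the example runs on its own.

```python
def split_strict(splitter, X, *, y, groups):
    """Call splitter.split() with y and groups as required keyword arguments.

    Hypothetical helper: y and groups have no defaults, so Python itself
    rejects a call that leaves them out, and an explicit None is rejected
    with a descriptive message before any splitting starts.
    """
    missing = [name for name, value in (("y", y), ("groups", groups)) if value is None]
    if missing:
        raise ValueError(
            "The following parameters must not be None: " + ", ".join(missing)
        )
    return splitter.split(X, y, groups)


class EchoSplitter:
    """Stand-in splitter used only to keep this sketch self-contained."""

    def split(self, X, y=None, groups=None):
        # Yield a single (train, test) pair of index lists.
        indices = list(range(len(X)))
        yield indices[1:], indices[:1]


X = [[0.1], [0.2], [0.3], [0.4]]
y = [0, 0, 1, 1]
groups = [0, 1, 0, 1]

train, test = next(split_strict(EchoSplitter(), X, y=y, groups=groups))
print(train, test)
```

With a real splitter the wrapper would be called the same way; the point is only that a signature without defaults makes the requirement visible before the splitter runs.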

Actual Results

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 2
      1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=y, groups=None))    # TypeError

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
    403 if self.n_splits > n_samples:
    404     raise ValueError(
    405         (
    406             "Cannot have number of splits n_splits={0} greater"
    407             " than the number of samples: n_samples={1}."
    408         ).format(self.n_splits, n_samples)
    409     )
--> 411 for train, test in super().split(X, y, groups):
    412     yield train, test

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
    140 X, y, groups = indexable(X, y, groups)
    141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
    143     train_index = indices[np.logical_not(test_index)]
    144     test_index = indices[test_index]

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
    149 def _iter_test_masks(self, X=None, y=None, groups=None):
    150     """Generates boolean masks corresponding to test sets.
    151 
    152     By default, delegates to _iter_test_indices(X, y, groups)
    153     """
--> 154     for test_index in self._iter_test_indices(X, y, groups):
    155         test_mask = np.zeros(_num_samples(X), dtype=bool)
    156         test_mask[test_index] = True

File /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1035, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
   1031 _, groups_inv, groups_cnt = np.unique(
   1032     groups, return_inverse=True, return_counts=True
   1033 )
   1034 y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
-> 1035 for class_idx, group_idx in zip(y_inv, groups_inv):
   1036     y_counts_per_group[group_idx, class_idx] += 1
   1038 y_counts_per_fold = np.zeros((self.n_splits, n_classes))

TypeError: iteration over a 0-d array

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 2
      1 sgkf = StratifiedGroupKFold(n_splits=5)
----> 2 next(sgkf.split(X=X, y=None, groups=g))

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:411, in _BaseKFold.split(self, X, y, groups)
    403 if self.n_splits > n_samples:
    404     raise ValueError(
    405         (
    406             "Cannot have number of splits n_splits={0} greater"
    407             " than the number of samples: n_samples={1}."
    408         ).format(self.n_splits, n_samples)
    409     )
--> 411 for train, test in super().split(X, y, groups):
    412     yield train, test

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:142, in BaseCrossValidator.split(self, X, y, groups)
    140 X, y, groups = indexable(X, y, groups)
    141 indices = np.arange(_num_samples(X))
--> 142 for test_index in self._iter_test_masks(X, y, groups):
    143     train_index = indices[np.logical_not(test_index)]
    144     test_index = indices[test_index]

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:154, in BaseCrossValidator._iter_test_masks(self, X, y, groups)
    149 def _iter_test_masks(self, X=None, y=None, groups=None):
    150     """Generates boolean masks corresponding to test sets.
    151 
    152     By default, delegates to _iter_test_indices(X, y, groups)
    153     """
--> 154     for test_index in self._iter_test_indices(X, y, groups):
    155         test_mask = np.zeros(_num_samples(X), dtype=bool)
    156         test_mask[test_index] = True

File  /<PATH>/lib/python3.12/site-packages/sklearn/model_selection/_split.py:1008, in StratifiedGroupKFold._iter_test_indices(self, X, y, groups)
   1006 allowed_target_types = ("binary", "multiclass")
   1007 if type_of_target_y not in allowed_target_types:
-> 1008     raise ValueError(
   1009         "Supported target types are: {}. Got {!r} instead.".format(
   1010             allowed_target_types, type_of_target_y
   1011         )
   1012     )
   1014 y = column_or_1d(y)
   1015 _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)

ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead.

Versions

System:
    python: 3.12.4 (main, Jul 23 2024, 09:14:16) [GCC 14.1.1 20240522]
executable: /<PATH>/bin/python
   machine: Linux-6.12.9-arch1-1-x86_64-with-glibc2.40

Python dependencies:
      sklearn: 1.6.1
          pip: 24.3.1
   setuptools: None
        numpy: 2.2.2
        scipy: 1.15.1
       Cython: None
       pandas: 2.2.3
   matplotlib: 3.10.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: /<PATH>/lib/python3.12/site-packages/numpy.libs/libscipy_openblas64_-6bb31eeb.so
        version: 0.3.28
threading_layer: pthreads
   architecture: SkylakeX

       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libscipy_openblas
       filepath: /<PATH>/lib/python3.12/site-packages/scipy.libs/libscipy_openblas-68440149.so
        version: 0.3.28
threading_layer: pthreads
   architecture: SkylakeX

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: libgomp
       filepath: /<PATH>/lib/python3.12/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
Teagum added the Bug and Needs Triage labels Jan 31, 2025
@StefanieSenger
Contributor

Hello @Teagum,

Thanks for taking the time to open this issue. StratifiedGroupKFold needs both y and groups, and both must be provided to its split(). I agree that the documentation on this is inconsistent with other splitters and confusing.

I have seen that the docs for GroupKFold.split() also suggest that users could pass groups=None, but there we raise an error with a proper message (ValueError: The 'groups' parameter should not be None.), which is more user-friendly.

I think we should improve the documentation and also make sure we raise helpful error messages. (I actually wonder why we haven't used @validate_params here.)
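A minimal sketch of what such an early check could look like, purely for illustration: the decorator require_not_none and the class ToySplitter are hypothetical names, and this is not scikit-learn's @validate_params. It only reuses the message format of the GroupKFold error quoted above.

```python
import functools
import inspect


def require_not_none(*names):
    """Hypothetical decorator: reject None for the named parameters."""

    def decorator(func):
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            for name in names:
                if bound.arguments[name] is None:
                    # Message format borrowed from the GroupKFold error.
                    raise ValueError(f"The '{name}' parameter should not be None.")
            return func(*args, **kwargs)

        return wrapper

    return decorator


class ToySplitter:
    """Stand-in class; not a scikit-learn splitter."""

    @require_not_none("y", "groups")
    def split(self, X, y=None, groups=None):
        # Generator body: yields one (train, test) pair of index lists.
        indices = list(range(len(X)))
        yield indices[1:], indices[:1]


try:
    ToySplitter().split([[1], [2], [3]], y=[0, 1, 0])
except ValueError as exc:
    print(exc)  # The 'groups' parameter should not be None.
```

Because the wrapper is a plain function around the generator, the check in this sketch runs when split() is called rather than at the first next(), so the error appears at the call site.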

I will go ahead and make a PR to fix this, unless you want to do it @Teagum?

StefanieSenger added the Validation and Documentation labels and removed the Bug and Needs Triage labels Jan 31, 2025
@Teagum
Author
Teagum commented Feb 1, 2025

Seems like a good first issue. So let me try it!

@StefanieSenger
Contributor
StefanieSenger commented Feb 1, 2025

Sure. And let's get a core dev's confirmation before you put too much work in. Maybe there is a reason for it that I cannot see.

Maybe @adrinjalali, would you confirm that it makes sense to make y and groups required inputs to the splitters' split() method?
Edit: I of course meant making y and groups required inputs to StratifiedGroupKFold.split(), not to all the splitters.

@adrinjalali
Member

This is an interesting one, and tangentially related to #26821.

Basically, historically, all splitters have had y and groups in their signature, even the ones which would never use groups.

That was before metadata routing was a thing. These days, we could potentially consolidate a lot of our splitters into very few classes (if not one), and do routing and validation in a nice, consistent way.

The proposed "fix" here would be somewhat of a patchwork, which we could apply to all relevant classes, but it would be a short-term fix.

I'm honestly not sure what the path forward in this particular case should be. In an ideal world, I'd like to see the large refactoring done in the splitter part of the code base, but I understand that's not an easy project to tackle, and I'm not sure when I'd have the chance to tackle it myself.

cc @scikit-learn/core-devs for consult really 😁

@StefanieSenger
Contributor

Thank you @adrinjalali. I can see that we might want to redesign all the splitters in the future, now that we have metadata routing. But we're talking about a two(-or-more)-year perspective here. Maybe it's okay to start in small steps and make some intermediate short-term fixes that bring the splitters more in line with each other in the meantime?

I don't think it would be an API-relevant change if users get better error messages and clearer documentation until we know what to do with the splitters in general.

@Teagum
Author
Teagum commented Feb 5, 2025

Alright, so I'll start a PR for this.

@hoipranav

Is this issue still open to work on?

@Teagum
Author
Teagum commented Feb 18, 2025

Hi, @hoipranav! Thank you for your interest in this issue. I already decided to work on this a few weeks back, but didn't have the time to get into it. It would be great if you would let me finish this. Thank you!
