StratifiedGroupKFold not ensuring Stratified splits · Issue #28218 · scikit-learn/scikit-learn · GitHub
@xavitator

Description


Describe the bug

The current implementation of the StratifiedGroupKFold class does not always produce stratified splits, especially when the dataset contains only a few samples.
In particular, nothing guarantees that the training set and the test set of each fold contain at least one sample of every class.

The following example illustrates the problem.

Steps/Code to Reproduce

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

# 6 samples, 3 classes (2 samples each), 4 groups; groups "b" and "c" each mix two classes
X = np.ones((6, 2))
y = np.array([1, 1, 0, 0, 2, 2])
groups = np.array(["a", "b", "b", "c", "c", "d"])
sgkf = StratifiedGroupKFold(n_splits=2, random_state=3, shuffle=True)
sgkf.get_n_splits(X, y)
# Print the sample indices, class labels and group labels of each fold
for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
    print(f"Fold {i}:")
    print(f"  Train: index={train_index}")
    print(f"       classes={y[train_index]}")
    print(f"         group={groups[train_index]}")
    print(f"  Test:  index={test_index}")
    print(f"       classes={y[test_index]}")
    print(f"         group={groups[test_index]}")

Expected Results

One possible expected result is the following:

Fold 0:
  Train: index=[1 2 5]
       classes=[1 0 2]
         group=['b' 'b' 'd']
  Test:  index=[0 3 4]
       classes=[1 0 2]
         group=['a' 'c' 'c']
Fold 1:
  Train: index=[0 3 4]
       classes=[1 0 2]
         group=['a' 'c' 'c']
  Test:  index=[1 2 5]
       classes=[1 0 2]
         group=['b' 'b' 'd']

This output is obtained by changing random_state in StratifiedGroupKFold from 3 to 5.
The error is either:

Actual Results

Output of the reproduction code above (random_state=3):

Fold 0:
  Train: index=[1 2]
       classes=[1 0]
         group=['b' 'b']
  Test:  index=[0 3 4 5]
       classes=[1 0 2 2]
         group=['a' 'c' 'c' 'd']
Fold 1:
  Train: index=[0 3 4 5]
       classes=[1 0 2 2]
         group=['a' 'c' 'c' 'd']
  Test:  index=[1 2]
       classes=[1 0]
         group=['b' 'b']

In this run, the splits do not preserve the class distribution: class 2 is absent from the training set of Fold 0 and, since the two folds swap roles, also absent from the test set of Fold 1. Applying a classifier to this split results in an error, because the model does not see every class during training.
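A small brute-force search shows that a class-complete 2-fold split does exist for this data. The sketch enumerates every way of dividing the four groups into two non-empty folds and keeps the divisions in which both halves contain all three classes. This exhaustive search is only an illustration that is feasible for a handful of groups; it is not how StratifiedGroupKFold assigns groups.

```python
from itertools import combinations

import numpy as np

y = np.array([1, 1, 0, 0, 2, 2])
groups = np.array(["a", "b", "b", "c", "c", "d"])
classes = set(y.tolist())
unique_groups = sorted(set(groups.tolist()))

valid = []
# A 2-fold split is fixed by the groups in the first fold; always putting
# group "a" there avoids counting each split twice.
rest = unique_groups[1:]
for k in range(len(rest)):  # k < len(rest) keeps the second fold non-empty
    for extra in combinations(rest, k):
        fold0 = {unique_groups[0], *extra}
        in_fold0 = np.isin(groups, list(fold0))
        # Each half is the training set in one fold and the test set in the
        # other, so both halves must contain every class.
        if set(y[in_fold0].tolist()) == classes and set(y[~in_fold0].tolist()) == classes:
            valid.append(sorted(fold0))

print(valid)  # prints [['a', 'c']]
```

The only such division is {a, c} versus {b, d}, which is exactly the split shown under Expected Results.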

Versions

System:
    python: 3.11.4 (main, Jun 20 2023, 17:23:00) [Clang 14.0.3 (clang-1403.0.22.14.1)]
executable: /Users/xavierdurand/Documents/SurgeLibrary/venv/bin/python
   machine: macOS-13.4.1-arm64-arm-64bit

Python dependencies:
      sklearn: 1.3.0
          pip: 23.0.1
   setuptools: 67.6.1
        numpy: 1.23.2
        scipy: 1.11.1
       Cython: 0.29.36
       pandas: 1.5.3
   matplotlib: 3.7.2
       joblib: 1.3.1
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
         prefix: libomp
       filepath: /Users/xavierdurand/Documents/SurgeLibrary/venv/lib/python3.11/site-packages/sklearn/.dylibs/libomp.dylib
        version: None
    num_threads: 10

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /Users/xavierdurand/Documents/SurgeLibrary/venv/lib/python3.11/site-packages/numpy/.dylibs/libopenblas64_.0.dylib
        version: 0.3.20
threading_layer: pthreads
   architecture: armv8
    num_threads: 10

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /Users/xavierdurand/Documents/SurgeLibrary/venv/lib/python3.11/site-packages/scipy/.dylibs/libopenblas.0.dylib
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: armv8
    num_threads: 10
