Description
Describe the bug
The existing implementation of the "StratifiedGroupKFold" class does not consistently achieve accurate stratified splits when dividing datasets into subsets, particularly when the dataset contains a relatively small number of samples.
The resulting splits are not guaranteed to contain at least one sample from every class in both the training and test sets.
The following example illustrates the issue.
Steps/Code to Reproduce
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

# Six samples, three classes (two samples each), four groups
X = np.ones((6, 2))
y = np.array([1, 1, 0, 0, 2, 2])
groups = np.array(["a", "b", "b", "c", "c", "d"])
sgkf = StratifiedGroupKFold(n_splits=2, random_state=3, shuffle=True)
sgkf.get_n_splits(X, y)
# Print the indices, classes and groups that land in each side of every fold
for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
    print(f"Fold {i}:")
    print(f" Train: index={train_index}")
    print(f" classes={y[train_index]}")
    print(f" group={groups[train_index]}")
    print(f" Test: index={test_index}")
    print(f" classes={y[test_index]}")
    print(f" group={groups[test_index]}")
Expected Results
One possible expected result is the following:
Fold 0:
Train: index=[1 2 5]
classes=[1 0 2]
group=['b' 'b' 'd']
Test: index=[0 3 4]
classes=[1 0 2]
group=['a' 'c' 'c']
Fold 1:
Train: index=[0 3 4]
classes=[1 0 2]
group=['a' 'c' 'c']
Test: index=[1 2 5]
classes=[1 0 2]
group=['b' 'b' 'd']
This result is obtained by changing the random_state in StratifiedGroupKFold to 5.
The error is either:
- the randomization used in StratifiedGroupKFold, as presented in issue "Improving stratification in StratifiedGroupKFold" #24656. This could be problematic when implementing an object close to the RepeatedStratifiedGroupKFold proposed in issue "Add RepeatedStratifiedGroupKFold" #24247.
- the class distribution is not enforced by the object, so the training set might not contain all classes. This could be a real problem when using the splitter in an automated pipeline.
Actual Results
Results of the run:
Fold 0:
Train: index=[1 2]
classes=[1 0]
group=['b' 'b']
Test: index=[0 3 4 5]
classes=[1 0 2 2]
group=['a' 'c' 'c' 'd']
Fold 1:
Train: index=[0 3 4 5]
classes=[1 0 2 2]
group=['a' 'c' 'c' 'd']
Test: index=[1 2]
classes=[1 0]
group=['b' 'b']
In this example, the splits produced by the splitter have an inconsistent class distribution. Specifically, class 2 is absent from the training set in Fold 0 and therefore absent from the test set in Fold 1. Applying a classifier to this split results in an error, because the classifier is not exposed to all classes during training.
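The missing class can also be detected programmatically, which may help in an automated pipeline. The following is a minimal sketch, not part of scikit-learn: `missing_classes_per_fold` is a hypothetical helper name, and the index pairs are copied by hand from the output above rather than produced by the splitter.

```python
def missing_classes_per_fold(splits, y):
    """Report, for each fold, which classes of y are absent from train and from test.

    Hypothetical helper for illustration only. `splits` is any iterable of
    (train_indices, test_indices) pairs; `y` is an indexable sequence of labels.
    """
    all_classes = set(y)
    report = []
    for fold, (train_idx, test_idx) in enumerate(splits):
        train_classes = {y[i] for i in train_idx}
        test_classes = {y[i] for i in test_idx}
        report.append(
            (
                fold,
                sorted(all_classes - train_classes),
                sorted(all_classes - test_classes),
            )
        )
    return report


# Labels from the reproduction, and the index pairs from "Actual Results"
y = [1, 1, 0, 0, 2, 2]
splits = [
    ([1, 2], [0, 3, 4, 5]),
    ([0, 3, 4, 5], [1, 2]),
]

for fold, missing_train, missing_test in missing_classes_per_fold(splits, y):
    print(f"Fold {fold}: missing in train={missing_train}, missing in test={missing_test}")
# Fold 0: missing in train=[2], missing in test=[]
# Fold 1: missing in train=[], missing in test=[2]
```

The same helper should also accept the generator returned by `sgkf.split(X, y, groups)` together with the NumPy array `y`, since it only iterates over the index pairs and indexes into `y`.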
Versions
System:
python: 3.11.4 (main, Jun 20 2023, 17:23:00) [Clang 14.0.3 (clang-1403.0.22.14.1)]
executable: /Users/xavierdurand/Documents/SurgeLibrary/venv/bin/python
machine: macOS-13.4.1-arm64-arm-64bit
Python dependencies:
sklearn: 1.3.0
pip: 23.0.1
setuptools: 67.6.1
numpy: 1.23.2
scipy: 1.11.1
Cython: 0.29.36
pandas: 1.5.3
matplotlib: 3.7.2
joblib: 1.3.1
threadpoolctl: 3.1.0
Built with OpenMP: True
threadpoolctl info:
user_api: openmp
internal_api: openmp
prefix: libomp
filepath: /Users/xavierdurand/Documents/SurgeLibrary/venv/lib/python3.11/site-packages/sklearn/.dylibs/libomp.dylib
version: None
num_threads: 10
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /Users/xavierdurand/Documents/SurgeLibrary/venv/lib/python3.11/site-packages/numpy/.dylibs/libopenblas64_.0.dylib
version: 0.3.20
threading_layer: pthreads
architecture: armv8
num_threads: 10
user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /Users/xavierdurand/Documents/SurgeLibrary/venv/lib/python3.11/site-packages/scipy/.dylibs/libopenblas.0.dylib
version: 0.3.21.dev
threading_layer: pthreads
architecture: armv8
num_threads: 10