8000 [MRG+1] fix StratifiedShuffleSplit with 2d y (#9044) · plagree/scikit-learn@93563b0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 93563b0

Browse files
veneMechCoder
authored andcommitted
[MRG+1] fix StratifiedShuffleSplit with 2d y (scikit-learn#9044)
* regression test and fix for 2d stratified shuffle split * strengthen non-overlap sss tests * clarify test and comment * remove iter from tests, use str instead of hash
1 parent 2c21479 commit 93563b0

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

sklearn/model_selection/_split.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,11 @@ def _iter_indices(self, X, y, groups=None):
14781478
y = check_array(y, ensure_2d=False, dtype=None)
14791479
n_train, n_test = _validate_shuffle_split(n_samples, self.test_size,
14801480
self.train_size)
1481+
1482+
if y.ndim == 2:
1483+
# for multi-label y, map each distinct row to its string repr:
1484+
y = np.array([str(row) for row in y])
1485+
14811486
classes, y_indices = np.unique(y, return_inverse=True)
14821487
n_classes = classes.shape[0]
14831488

sklearn/model_selection/tests/test_split.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,10 +663,37 @@ def test_stratified_shuffle_split_overlap_train_test_bug():
663663
sss = StratifiedShuffleSplit(n_splits=1,
664664
test_size=0.5, random_state=0)
665665

666-
train, test = next(iter(sss.split(X=X, y=y)))
666+
train, test = next(sss.split(X=X, y=y))
667667

668+
# no overlap
668669
assert_array_equal(np.intersect1d(train, test), [])
669670

671+
# complete partition
672+
assert_array_equal(np.union1d(train, test), np.arange(len(y)))
673+
674+
675+
def test_stratified_shuffle_split_multilabel():
676+
# fix for issue 9037
677+
for y in [np.array([[0, 1], [1, 0], [1, 0], [0, 1]]),
678+
np.array([[0, 1], [1, 1], [1, 1], [0, 1]])]:
679+
X = np.ones_like(y)
680+
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
681+
train, test = next(sss.split(X=X, y=y))
682+
y_train = y[train]
683+
y_test = y[test]
684+
685+
# no overlap
686+
assert_array_equal(np.intersect1d(train, test), [])
687+
688+
# complete partition
689+
assert_array_equal(np.union1d(train, test), np.arange(len(y)))
690+
691+
# correct stratification of entire rows
692+
# (by design, here y[:, 0] uniquely determines the entire row of y)
693+
expected_ratio = np.mean(y[:, 0])
694+
assert_equal(expected_ratio, np.mean(y_train[:, 0]))
695+
assert_equal(expected_ratio, np.mean(y_test[:, 0]))
696+
670697

671698
def test_predefinedsplit_with_kfold_split():
672699
# Check that PredefinedSplit can reproduce a split generated by Kfold.

0 commit comments

Comments
 (0)
0