8000 Merge pull request #6379 from lesteve/fix-stratified-shuffle-split-tr… · raghavrv/scikit-learn@150afe6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 150afe6

Browse files
committed
Merge pull request scikit-learn#6379 from lesteve/fix-stratified-shuffle-split-train-test-overlap
[MRG+1] fix StratifiedShuffleSplit train and test overlap
2 parents 1049642 + 07728d9 commit 150afe6

File tree

5 files changed

+47
-6
lines changed

5 files changed

+47
-6
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ Bug fixes
142142
``transform`` or ``predict_proba`` are called on the non-fitted estimator.
143143
by `Sebastian Raschka`_.
144144

145+
- Fixed bug in :class:`model_selection.StratifiedShuffleSplit`
146+
where train and test sample could overlap in some edge cases,
147+
see `#6121 <https://github.com/scikit-learn/scikit-learn/issues/6121>`_ for
148+
more details. By `Loic Esteve`_.
149+
145150
API changes summary
146151
-------------------
147152

sklearn/cross_validation.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,14 +1010,19 @@ def _iter_indices(self):
10101010
# Because of rounding issues (as n_train and n_test are not
10111011
# dividers of the number of elements per class), we may end
10121012
# up here with less samples in train and test than asked for.
1013-
if len(train) < self.n_train or len(test) < self.n_test:
1013+
if len(train) + len(test) < self.n_train + self.n_test:
10141014
# We complete by affecting randomly the missing indexes
10151015
missing_idx = np.where(bincount(train + test,
10161016
minlength=len(self.y)) == 0,
10171017
)[0]
10181018
missing_idx = rng.permutation(missing_idx)
1019-
train.extend(missing_idx[:(self.n_train - len(train))])
1020-
test.extend(missing_idx[-(self.n_test - len(test)):])
1019+
n_missing_train = self.n_train - len(train)
1020+
n_missing_test = self.n_test - len(test)
1021+
1022+
if n_missing_train > 0:
1023+
train.extend(missing_idx[:n_missing_train])
1024+
if n_missing_test > 0:
1025+
test.extend(missing_idx[-n_missing_test:])
10211026

10221027
train = rng.permutation(train)
10231028
test = rng.permutation(test)

sklearn/model_selection/_split.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,13 +1108,18 @@ def _iter_indices(self, X, y, labels=None):
11081108
# Because of rounding issues (as n_train and n_test are not
11091109
# dividers of the number of elements per class), we may end
11101110
# up here with less samples in train and test than asked for.
1111-
if len(train) < n_train or len(test) < n_test:
1111+
if len(train) + len(test) < n_train + n_test:
11121112
# We complete by affecting randomly the missing indexes
11131113
missing_indices = np.where(bincount(train + test,
11141114
minlength=len(y)) == 0)[0]
11151115
missing_indices< 10000 /span> = rng.permutation(missing_indices)
1116-
train.extend(missing_indices[:(n_train - len(train))])
1117-
test.extend(missing_indices[-(n_test - len(test)):])
1116+
n_missing_train = n_train - len(train)
1117+
n_missing_test = n_test - len(test)
1118+
1119+
if n_missing_train > 0:
1120+
train.extend(missing_indices[:n_missing_train])
1121+
if n_missing_test > 0:
1122+
test.extend(missing_indices[-n_missing_test:])
11181123

11191124
train = rng.permutation(train)
11201125
test = rng.permutation(test)

sklearn/model_selection/tests/test_split.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,20 @@ def assert_counts_are_ok(idx_counts, p):
610610
assert_counts_are_ok(test_counts, ex_test_p)
611611

612612

613+
def test_stratified_shuffle_split_overlap_train_test_bug():
614+
# See https://github.com/scikit-learn/scikit-learn/issues/6121 for
615+
# the original bug report
616+
y = [0, 1, 2, 3] * 3 + [4, 5] * 5
617+
X = np.ones_like(y)
618+
619+
splits = StratifiedShuffleSplit(n_iter=1,
620+
test_size=0.5, random_state=0)
621+
622+
train, test = next(iter(splits.split(X=X, y=y)))
623+
624+
assert_array_equal(np.intersect1d(train, test), [])
625+
626+
613627
def test_predefinedsplit_with_kfold_split():
614628
# Check that PredefinedSplit can reproduce a split generated by Kfold.
615629
folds = -1 * np.ones(10)

sklearn/tests/test_cross_validation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,18 @@ def assert_counts_are_ok(idx_counts, p):
546546
assert_counts_are_ok(test_counts, ex_test_p)
547547

548548

549+
def test_stratified_shuffle_split_overlap_train_test_bug():
550+
# See https://github.com/scikit-learn/scikit-learn/issues/6121 for
551+
# the original bug report
552+
labels = [0, 1, 2, 3] * 3 + [4, 5] * 5
553+
554+
splits = cval.StratifiedShuffleSplit(labels, n_iter=1,
555+
test_size=0.5, random_state=0)
556+
train, test = next(iter(splits))
557+
558+
assert_array_equal(np.intersect1d(train, test), [])
559+
560+
549561
def test_predefinedsplit_with_kfold_split():
550562
# Check that PredefinedSplit can reproduce a split generated by Kfold.
551563
folds = -1 * np.ones(10)

0 commit comments

Comments
 (0)
0