8000 From #4583 - LabelShuffleSplit and tests · raghavrv/scikit-learn@c7354ed · GitHub
[go: up one dir, main page]

Skip to content

Commit c7354ed

Browse files
committed
From scikit-learn#4583 - LabelShuffleSplit and tests
1 parent 0e2e54c commit c7354ed

File tree

3 files changed

+173
-60
lines changed

3 files changed

+173
-60
lines changed

sklearn/model_selection/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._split import LeavePLabelOut
88
from ._split import LeavePOut
99
from ._split import ShuffleSplit
10+
from ._split import LabelShuffleSplit
1011
from ._split import StratifiedShuffleSplit
1112
from ._split import PredefinedSplit
1213
from ._split import train_test_split
@@ -27,7 +28,8 @@
2728
__all__ = ('BaseCrossValidator', 'GridSearchCV', 'KFold', 'LabelKFold',
2829
'LeaveOneLabelOut', 'LeaveOneOut', 'LeavePLabelOut', 'LeavePOut',
2930
'ParameterGrid', 'ParameterSampler', 'PredefinedSplit',
30-
'RandomizedSearchCV', 'ShuffleSplit', 'StratifiedKFold',
31-
'StratifiedShuffleSplit', 'check_cv', 'cross_val_predict',
32-
'cross_val_score', 'fit_grid_point', 'learning_curve',
33-
'permutation_test_score', 'train_test_split', 'validation_curve')
31+
'RandomizedSearchCV', 'ShuffleSplit', 'LabelShuffleSplit',
32+
'StratifiedKFold', 'StratifiedShuffleSplit', 'check_cv',
33+
'cross_val_predict', 'cross_val_score', 'fit_grid_point',
34+
'learning_curve', 'permutation_test_score', 'train_test_split',
35+
'validation_curve')

sklearn/model_selection/_split.py

Lines changed: 121 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
'LeavePLabelOut',
4040
'LeavePOut',
4141
'ShuffleSplit',
42+
'LabelShuffleSplit',
4243
'StratifiedKFold',
4344
'StratifiedShuffleSplit',
4445
'PredefinedSplit',
@@ -825,70 +826,70 @@ def get_n_splits(self, X=None, y=None, labels=None):
825826
return self.n_iter
826827

827828

828-
def _validate_shuffle_split_init(test_size, train_size):
829-
if test_size is None and train_size is None:
830-
raise ValueError('test_size and train_size can not both be None')
829+
class LabelShuffleSplit(ShuffleSplit):
830+
'''Shuffle-Labels-Out cross-validation iterator
831831

832-
if test_size is not None:
833-
if np.asarray(test_size).dtype.kind == 'f':
834-
if test_size >= 1.:
835-
raise ValueError(
836-
'test_size=%f should be smaller '
837-
'than 1.0 or be an integer' % test_size)
838-
elif np.asarray(test_size).dtype.kind != 'i':
839-
# int values are checked during split based on the input
840-
raise ValueError("Invalid value for test_size: %r" % test_size)
832+
Provides randomized train/test indices to split data according to a
833+
third-party provided label. This label information can be used to encode
834+
arbitrary domain specific stratifications of the samples as integers.
841835

842-
if train_size is not None:
843-
if np.asarray(train_size).dtype.kind == 'f':
844-
if train_size >= 1.:
845-
raise ValueError("train_size=%f should be smaller "
846-
"than 1.0 or be an integer" % train_size)
847-
elif ((np.asarray(test_size).dtype.kind == 'f') and
848-
((train_size + test_size) > 1.)):
849-
raise ValueError('The sum of test_size and train_size = %f, '
850-
'should be smaller than 1.0. Reduce '
851-
'test_size and/or train_size.' %
852-
(train_size + test_size))
853-
elif np.asarray(train_size).dtype.kind != 'i':
854-
# int values are checked during split based on the input
855-
raise ValueError("Invalid value for train_size: %r" % train_size)
836+
For instance the labels could be the year of collection of the samples
837+
and thus allow for cross-validation against time-based splits.
856838

839+
The difference between LeavePLabelOut and LabelShuffleSplit is that
840+
the former generates splits using all subsets of size ``p`` unique labels,
841+
whereas LabelShuffleSplit generates a user-determined number of random
842+
test splits, each with a user-determined fraction of unique labels.
857843

858-
def _validate_shuffle_split(n, test_size, train_size):
859-
if ((test_size is not None) and (np.asarray(test_size).dtype.kind == 'i')
860-
and (test_size >= n)):
861-
raise ValueError('test_size=%d should be smaller '
862-
'than the number of samples %d' % (test_size, n))
844+
For exa 6D40 mple, a less computationally intensive alternative to
845+
``LeavePLabelOut(p=10)`` would be
846+
``LabelShuffleSplit(test_size=10, n_iter=100)``.
863847

864-
if ((train_size is not None) and (np.asarray(train_size).dtype.kind == 'i')
865-
and (train_size >= n)):
866-
raise ValueError("train_size=%d should be smaller "
867-
"than the number of samples %d" % (train_size, n))
848+
Note: The parameters ``test_size`` and ``train_size`` refer to labels, and
849+
not to samples, as in ShuffleSplit.
868850

869-
if np.asarray(test_size).dtype.kind == 'f':
870-
n_test = ceil(test_size * n)
871-
elif np.asarray(test_size).dtype.kind == 'i':
872-
n_test = float(test_size)
873851

874-
if train_size is None:
875-
n_train = n - n_test
876-
else:
877-
if np.asarray(train_size).dtype.kind == 'f':
878-
n_train = floor(train_size * n)
879-
else:
880-
n_train = float(train_size)
852+
Parameters
853+
----------
854+
n_iter : int (default 5)
855+
Number of re-shuffling & splitting iterations.
881856

882-
if test_size is None:
883-
n_test = n - n_train
857+
test_size : float (default 0.2), int, or None
858+
If float, should be between 0.0 and 1.0 and represent the
859+
proportion of the labels to include in the test split. If
860+
int, represents the absolute number of test labels. If None,
861+
the value is automatically set to the complement of the train size.
884862

885-
if n_train + n_test > n:
886-
raise ValueError('The sum of train_size and test_size = %d, '
887-
'should be smaller than the number of '
888-
'samples %d. Reduce test_size and/or '
889-
'train_size.' % (n_train + n_test, n))
863+
train_size : float, int, or None (default is None)
864+
If float, should be between 0.0 and 1.0 and represent the
865+
proportion of the labels to include in the train split. If
866+
int, represents the absolute number of train labels. If None,
867+
the value is automatically set to the complement of the test size.
890868

891-
return int(n_train), int(n_test)
869+
random_state : int or RandomState
870+
Pseudo-random number generator state used for random sampling.
871+
'''
872+
873+
def __init__(self, n_iter=5, test_size=0.2, train_size=None,
874+
random_state=None):
875+
super(LabelShuffleSplit, self).__init__(
876+
n_iter=n_iter,
877+
test_size=test_size,
878+
train_size=train_size,
879+
random_state=random_state)
880+
881+
882+
def _iter_indices(self, X, y, labels):
883+
classes, label_indices = np.unique(labels, return_inverse=True)
884+
for label_train, label_test in super(
885+
LabelShuffleSplit, self)._iter_indices(X=classes):
886+
# these are the indices of classes in the partition
887+
# invert them into data indices
888+
889+
train = np.flatnonzero(np.in1d(label_indices, label_train))
890+
test = np.flatnonzero(np.in1d(label_indices, label_test))
891+
892+
yield train, test
892893

893894

894895
class StratifiedShuffleSplit(BaseShuffleSplit):
@@ -1018,6 +1019,72 @@ def get_n_splits(self, X=None, y=None, labels=None):
10181019
return self.n_iter
10191020

10201021

1022+
def _validate_shuffle_split_init(test_size, train_size):
1023+
if test_size is None and train_size is None:
1024+
raise ValueError('test_size and train_size can not both be None')
1025+
1026+
if test_size is not None:
1027+
if np.asarray(test_size).dtype.kind == 'f':
1028+
if test_size >= 1.:
1029+
raise ValueError(
1030+
'test_size=%f should be smaller '
1031+
'than 1.0 or be an integer' % test_size)
1032+
elif np.asarray(test_size).dtype.kind != 'i':
1033+
# int values are checked during split based on the input
1034+
raise ValueError("Invalid value for test_size: %r" % test_size)
1035+
1036+
if train_size is not None:
1037+
if np.asarray(train_size).dtype.kind == 'f':
1038+
if train_size >= 1.:
1039+
raise ValueError("train_size=%f should be smaller "
1040+
"than 1.0 or be an integer" % train_size)
1041+
elif ((np.asarray(test_size).dtype.kind == 'f') and
1042+
((train_size + test_size) > 1.)):
1043+
raise ValueError('The sum of test_size and train_size = %f, '
1044+
'should be smaller than 1.0. Reduce '
1045+
'test_size and/or train_size.' %
1046+
(train_size + test_size))
1047+
elif np.asarray(train_size).dtype.kind != 'i':
1048+
# int values are checked during split based on the input
1049+
raise ValueError("Invalid value for train_size: %r" % train_size)
1050+
1051+
1052+
def _validate_shuffle_split(n, test_size, train_size):
1053+
if ((test_size is not None) and (np.asarray(test_size).dtype.kind == 'i')
1054+
and (test_size >= n)):
1055+
raise ValueError('test_size=%d should be smaller '
1056+
'than the number of samples %d' % (test_size, n))
1057+
1058+
if ((train_size is not None) and (np.asarray(train_size).dtype.kind == 'i')
1059+
and (train_size >= n)):
1060+
raise ValueError("train_size=%d should be smaller "
1061+
"than the number of samples %d" % (train_size, n))
1062+
1063+
if np.asarray(test_size).dtype.kind == 'f':
1064+
n_test = ceil(test_size * n)
1065+
elif np.asarray(test_size).dtype.kind == 'i':
1066+
n_test = float(test_size)
1067+
1068+
if train_size is None:
1069+
n_train = n - n_test
1070+
else:
1071+
if np.asarray(train_size).dtype.kind == 'f':
1072+
n_train = floor(train_size * n)
1073+
else:
1074+
n_train = float(train_size)
1075+
1076+
if test_size is None:
1077+
n_test = n - n_train
1078+
1079+
if n_train + n_test > n:
1080+
raise ValueError('The sum of train_size and test_size = %d, '
1081+
'should be smaller than the number of '
1082+
'samples %d. Reduce test_size and/or '
1083+
'train_size.' % (n_train + n_test, n))
1084+
1085+
return int(n_train), int(n_test)
1086+
1087+
10211088
class PredefinedSplit(BaseCrossValidator):
10221089
"""Predefined split cross-validator
10231090

sklearn/model_selection/tests/test_split.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from itertools import combinations
99

1010
from sklearn.utils.testing import assert_true
11+
from sklearn.utils.testing import assert_false
1112
from sklearn.utils.testing import assert_equal
1213
from sklearn.utils.testing import assert_almost_equal
1314
from sklearn.utils.testing import assert_raises
@@ -30,6 +31,7 @@
3031
from sklearn.model_selection import LeavePOut
3132
from sklearn.model_selection import LeavePLabelOut
3233
from sklearn.model_selection import ShuffleSplit
34+
from sklearn.model_selection import LabelShuffleSplit
3335
from sklearn.model_selection import StratifiedShuffleSplit
3436
from sklearn.model_selection import PredefinedSplit
3537
from sklearn.model_selection import check_cv
@@ -566,6 +568,48 @@ def test_predefinedsplit_with_kfold_split():
566568
assert_array_equal(ps_test, kf_test)
567569

568570

571+
def test_label_shuffle_split():
572+
labels = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
573+
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
574+
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
575+
np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
576+
]
577+
578+
for l in labels:
579+
X = y = np.ones(len(l))
580+
n_iter = 6
581+
test_size = 1./3
582+
slo = LabelShuffleSplit(n_iter, test_size=test_size, random_state=0)
583+
584+
# Make sure the repr works
585+
repr(slo)
586+
587+
# Test that the length is correct
588+
assert_equal(slo.get_n_splits(X, y, labels=l), n_iter)
589+
590+
l_unique = np.unique(l)
591+
592+
for train, test in slo.split(X, y, labels=l):
593+
# First test: no train label is in the test set and vice versa
594+
l_train_unique = np.unique(l[train])
595+
l_test_unique = np.unique(l[test])
596+
assert_false(np.any(np.in1d(l[train], l_test_unique)))
597+
assert_false(np.any(np.in1d(l[test], l_train_unique)))
598+
599+
# Second test: train and test add up to all the data
600+
assert_equal(l[train].size + l[test].size, l.size)
601+
602+
# Third test: train and test are disjoint
603+
assert_array_equal(np.intersect1d(train, test), [])
604+
605+
# Fourth test:
606+
# unique train and test labels are correct, +- 1 for rounding error
607+
assert_true(abs(len(l_test_unique) -
608+
round(test_size * len(l_unique))) <= 1)
609+
assert_true(abs(len(l_train_unique) -
610+
round((1.0 - test_size) * len(l_unique))) <= 1)
611+
612+
569613
def test_leave_label_out_changing_labels():
570614
# Check that LeaveOneLabelOut and LeavePLabelOut work normally if
571615
# the labels variable is changed before calling split
@@ -790,7 +834,7 @@ def test_label_kfold():
790834
ideal_n_labels_per_fold = n_samples // n_folds
791835

792836
len(np.unique(labels))
793-
# Get the test fold indices from the test set indices of each fold
837+
# Get the test fold indices from the test set indices of each fold
794838
folds = np.zeros(n_samples)
795839
for i, (_, test) in enumerate(LabelKFold(n_folds).split(X, y, labels)):
796840
folds[test] = i
@@ -827,7 +871,7 @@ def test_label_kfold():
827871

828872
X = y = np.ones(n_samples)
829873

830-
# Get the test fold indices from the test set indices of each fold
874+
# Get the test fold indices from the test set indices of each fold
831875
folds = np.zeros(n_samples)
832876
for i, (_, test) in enumerate(LabelKFold(n_folds).split(X, y, labels)):
833877
folds[test] = i

0 commit comments

Comments
 (0)
0