8000 LeaveOut and LeaveLabelOut working with split · scikit-learn/scikit-learn@0945b5b · GitHub
[go: up one dir, main page]

Skip to content

Commit 0945b5b

Browse files
pignacioraghavrv
authored andcommitted
LeaveOut and LeaveLabelOut working with split
1 parent 9d67b0e commit 0945b5b

File tree

1 file changed

+40
-25
lines changed

1 file changed

+40
-25
lines changed

sklearn/model_selection/partition.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ class LeaveOneOut(_PartitionIterator):
150150
domain-specific stratification of the dataset.
151151
"""
152152

153-
def _iter_test_indices(self):
154-
return range(self.n)
153+
def _iter_test_indices(self, y):
154+
return range(self._sample_size(y))
155155

156156
def __repr__(self):
157157
return '%s.%s(n=%i)' % (
@@ -161,6 +161,7 @@ def __repr__(self):
161161
)
162162

163163
def __len__(self):
164+
# TODO: remove?
164165
return self.n
165166

166167

@@ -208,12 +209,14 @@ class LeavePOut(_PartitionIterator):
208209
TRAIN: [0 1] TEST: [2 3]
209210
"""
210211

211-
def __init__(self, n, p, indices=None):
212+
def __init__(self, n=None, p=None, indices=None):
212213
super(LeavePOut, self).__init__(n, indices)
214+
if p is None:
215+
raise ValueError("LeavePOut: must supply p")
213216
self.p = p
214217

215-
def _iter_test_indices(self):
216-
for comb in combinations(range(self.n), self.p):
218+
def _iter_test_indices(self, y):
219+
for comb in combinations(range(self._sample_size(y)), self.p):
217220
yield np.array(comb)
218221

219222
def __repr__(self):
@@ -225,6 +228,7 @@ def __repr__(self):
225228
)
226229

227230
def __len__(self):
231+
# TODO: remove?
228232
return int(factorial(self.n) / factorial(self.n - self.p)
229233
/ factorial(self.p))
230234

@@ -512,16 +516,20 @@ class LeaveOneLabelOut(_PartitionIterator):
512516
513517
"""
514518

515-
def __init__(self, labels, indices=None):
516-
super(LeaveOneLabelOut, self).__init__(len(labels), indices)
517-
# We make a copy of labels to avoid side-effects during iteration
518-
self.labels = np.array(labels, copy=True)
519-
self.unique_labels = np.unique(labels)
520-
self.n_unique_labels = len(self.unique_labels)
519+
def __init__(self, labels=None, indices=None):
520+
n = None if labels is None else len(labels)
521+
super(LeaveOneLabelOut, self).__init__(n, indices)
522+
self.labels = labels
521523

522-
def _iter_test_masks(self):
523-
for i in self.unique_labels:
524-
yield self.labels == i
524+
def _iter_test_masks(self, y):
525+
labels = self.labels if y is None else y
526+
# We make a copy of labels to avoid side-effects during iteration
527+
labels = np.array(labels, copy=True)
528+
unique_labels = np.unique(labels)
529+
n_unique_labels = len(unique_labels)
530+
for i in unique_labels:
531+
print("yielding", labels == i)
532+
yield labels == i
525533

526534
def __repr__(self):
527535
return '%s.%s(labels=%s)' % (
@@ -531,6 +539,7 @@ def __repr__(self):
531539
)
532540

533541
def __len__(self):
542+
# TODO: remove?
534543
return self.n_unique_labels
535544

536545

@@ -585,21 +594,27 @@ class LeavePLabelOut(_PartitionIterator):
585594
[5 6]] [1] [2 1]
586595
"""
587 57AE 596

588-
def __init__(self, labels, p, indices=None):
589-
# We make a copy of labels to avoid side-effects during iteration
590-
super(LeavePLabelOut, self).__init__(len(labels), indices)
591-
self.labels = np.array(labels, copy=True)
592-
self.unique_labels = np.unique(labels)
593-
self.n_unique_labels = len(self.unique_labels)
597+
def __init__(self, labels=None, p=None, indices=None):
598+
n = None if labels is None else len(labels)
599+
super(LeavePLabelOut, self).__init__(n, indices)
600+
if p is None:
601+
raise ValueError("LeavePLabelOut: must supply p")
602+
594603
self.p = p
604+
self.labels = labels
595605

596-
def _iter_test_masks(self):
597-
comb = combinations(range(self.n_unique_labels), self.p)
606+
def _iter_test_masks(self, y):
607+
labels = self.labels if y is None else y
608+
# We make a copy of labels to avoid side-effects during iteration
609+
labels = np.array(labels, copy=True)
610+
unique_labels = np.unique(labels)
611+
n_unique_labels = len(unique_labels)
612+
comb = combinations(range(n_unique_labels), self.p)
598613
for idx in comb:
599-
test_index = self._empty_mask()
614+
test_index = self._empty_mask(labels)
600615
idx = np.array(idx)
601-
for l in self.unique_labels[idx]:
602-
test_index[self.labels == l] = True
616+
for l in unique_labels[idx]:
617+
test_index[labels == l] = True
603618
yield test_index
604619

605620
def __repr__(self):

0 commit comments

Comments
 (0)
0