|
25 | 25 | from sklearn.externals.six.moves import zip
|
26 | 26 |
|
27 | 27 |
|
28 |
| -class _PartitionTestGenerator(object): |
29 |
| - def __init__(self, mask_generator, y, n=None, indices=None): |
30 |
| - self._mask_generator = mask_generator |
31 |
| - if indices is None: |
32 |
| - indices = True |
33 |
| - else: |
34 |
| - warnings.warn("The indices parameter is deprecated and will be " |
35 |
| - "removed (assumed True) in 0.17", DeprecationWarning, |
36 |
| - stacklevel=1) |
37 |
| - if y is None: |
38 |
| - if n is None: |
39 |
| - raise ValueError("Must supply y or n parameters") |
40 |
| - if abs(n - int(n)) >= np.finfo('f').eps: |
41 |
| - raise ValueError("n must be an integer") |
42 |
| - self._n = int(n) |
43 |
| - self._indices = indices |
44 |
| - else: |
45 |
| - self._n = len(y) # TODO: Check for dict of arrays or DataFrame |
46 |
| - self._indices = True |
47 |
| - self._y = y |
48 |
| - |
49 |
| - |
50 |
| - def __iter__(self): |
51 |
| - indices = self._indices |
52 |
| - if indices: |
53 |
| - ind = np.arange(self._n) |
54 |
| - for test_index in self._mask_generator._iter_test_masks(self._y): |
55 |
| - train_index = np.logical_not(test_index) |
56 |
| - if indices: |
57 |
| - train_index = ind[train_index] |
58 |
| - test_index = ind[test_index] |
59 |
| - yield train_index, test_index |
60 |
| - |
61 |
| - |
62 |
| - |
63 | 28 | class _PartitionIterator(with_metaclass(ABCMeta)):
|
64 | 29 | """Base class for CV iterators where train_mask = ~test_mask
|
65 | 30 |
|
@@ -97,23 +62,21 @@ def indices(self):
|
97 | 62 |
|
98 | 63 | def __iter__(self):
|
99 | 64 | #TODO: deprecation warning
|
100 |
| - y = None |
101 |
| - self._pre_split_check(y) |
102 |
| - for train, test in _PartitionTestGenerator(self, y, n=self.n, indices=self.indices): |
103 |
| - yield train, test |
104 |
| -# indices = self._indices |
105 |
| -# if indices: |
106 |
| -# ind = np.arange(self.n) |
107 |
| -# for test_index in self._iter_test_masks(): |
108 |
| -# train_index = np.logical_not(test_index) |
109 |
| -# if indices: |
110 |
| -# train_index = ind[train_index] |
111 |
| -# test_index = ind[test_index] |
112 |
| -# yield train_index, test_index |
| 65 | + if self.n is None: |
| 66 | + raise ValueError("Cannot iterate dataless CV iterator") |
| 67 | + return self.split(None) |
113 | 68 |
|
114 | 69 | def split(self, y):
|
115 | 70 | self._pre_split_check(y)
|
116 |
| - return _PartitionTestGenerator(self, y, indices=self.indices) |
| 71 | + indices = self._indices |
| 72 | + if indices: |
| 73 | + ind = np.arange(self._sample_size(y)) |
| 74 | + for test_index in self._iter_test_masks(y): |
| 75 | + train_index = np.logical_not(test_index) |
| 76 | + if indices: |
| 77 | + train_index = ind[train_index] |
| 78 | + test_index = ind[test_index] |
| 79 | + yield train_index, test_index |
117 | 80 |
|
118 | 81 | def _pre_split_check(self, y):
|
119 | 82 | pass
|
|
0 commit comments