8000 [WIP] fixes #14257 - New Approach plus flake8 fixes · getgaurav2/scikit-learn@10df234 · GitHub 8000
[go: up one dir, main page]

Skip to content

Commit 10df234

Browse files
author
Gaurav Chawla
committed
[WIP] fixes scikit-learn#14257 - New Approach plus flake8 fixes
1 parent 6182361 commit 10df234

File tree

1 file changed

+39
-38
lines changed

1 file changed

+39
-38
lines changed

sklearn/model_selection/_split.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2184,42 +2184,43 @@ def _build_repr(self):
21842184

21852185

21862186
class GroupTimeSeriesSplit(TimeSeriesSplit):
2187+
def __init__(self, n_splits=5, max_train_size=None):
2188+
super().__init__(n_splits)
2189+
self.max_train_size = max_train_size
21872190

2188-
2189-
def __init__(self, n_splits=5, max_train_size=None):
2190-
super().__init__(n_splits, shuffle=False, random_state=None)
2191-
self.max_train_size = max_train_size
2192-
2193-
2194-
def split(self, X, y=None, groups=None):
2195-
2196-
X, y, groups = indexable(X, y, groups)
2197-
n_samples = _num_samples(X)
2198-
n_splits = self.n_splits
2199-
n_folds = n_splits + 1
2200-
if n_folds > n_samples:
2201-
raise ValueError(
2202-
("Cannot have number of folds ={0} greater"
2203-
" than the number of samples: {1}.").format(n_folds,n_samples))
2204-
unique_groups = np.unique(groups)
2205-
groups_dict = {key: [] for key in unique_groups}
2206-
indices = np.arange(n_samples)
2207-
for i in indices:
2208-
temp = groups_dict.get(groups[i])
2209-
temp.append(i)
2210-
groups_dict[groups[i]] = temp
2211-
group_combinations = combinations(unique_groups, 2)
2212-
for group_keys in list(group_combinations):
2213-
index_list1 = groups_dict.get(group_keys[0])
2214-
index_list2 = groups_dict.get(group_keys[1])
2215-
sample_len = min(len(index_list1) ,len(index_list2))
2216-
test_size = (sample_len // n_folds)
2217-
test_starts = range(test_size + n_samples % n_folds,
2218-
n_samples, test_size)
2219-
for test_start in test_starts:
2220-
if self.max_train_size and self.max_train_size < test_start:
2221-
yield (indices[test_start - self.max_train_size:test_start],
2222-
indices[test_start:test_start + test_size])
2223-
else:
2224-
yield (indices[:test_start],
2225-
indices[test_start:test_start + test_size])
2191+
def split(self, X, y=None, groups=None):
2192+
X, y, groups = indexable(X, y, groups)
2193+
n_samples = _num_samples(X)
2194+
n_splits = self.n_splits
2195+
n_folds = n_splits + 1
2196+
indices = np.arange(n_samples)
2197+
if n_folds > n_samples:
2198+
raise ValueError(
2199+
("Cannot have number of folds ={0} greater"
2200+
" than the number of samples: {1}.").format(n_folds, n_samples))
2201+
else:
2202+
test_size = (n_samples // n_folds)
2203+
test_starts = range(test_size + n_samples % n_folds, n_samples,
2204+
test_size)
2205+
for test_start in test_starts:
2206+
if self.max_train_size and self.max_train_size < test_start:
2207+
train_array = indices[test_start -
2208+
self.max_train_size:test_start]
2209+
test_array = indices[test_start:test_start + test_size]
2210+
train_group = [groups[i] for i in train_array]
2211+
for i in test_array:
2212+
if groups[i] in train_group:
2213+
test_array = test_array[test_array != i]
2214+
if test_array.size <= 1:
2215+
continue
2216+
yield(train_array, test_array)
2217+
else:
2218+
train_array = indices[:test_start]
2219+
test_array = indices[test_start:test_start + test_size]
2220+
train_group = [groups[i] for i in train_array]
2221+
for i in test_array:
2222+
if groups[i] in train_group:
2223+
test_array = test_array[test_array != i]
2224+
if test_array.size <= 1:
2225+
continue
2226+
yield(train_array, test_array)

0 commit comments

Comments
 (0)
0