@@ -2184,42 +2184,43 @@ def _build_repr(self):
2184
2184
2185
2185
2186
2186
class GroupTimeSeriesSplit (TimeSeriesSplit ):
2187
+ def __init__ (self , n_splits = 5 , max_train_size = None ):
2188
+ super ().__init__ (n_splits )
2189
+ self .max_train_size = max_train_size
2187
2190
2188
-
2189
- def __init__ (self , n_splits = 5 , max_train_size = None ):
2190
- super ().__init__ (n_splits , shuffle = False , random_state = None )
2191
- self .max_train_size = max_train_size
2192
-
2193
-
2194
- def split (self , X , y = None , groups = None ):
2195
-
2196
- X , y , groups = indexable (X , y , groups )
2197
- n_samples = _num_samples (X )
2198
- n_splits = self .n_splits
2199
- n_folds = n_splits + 1
2200
- if n_folds > n_samples :
2201
- raise ValueError (
2202
- ("Cannot have number of folds ={0} greater"
2203
- " than the number of samples: {1}." ).format (n_folds ,n_samples ))
2204
- unique_groups = np .unique (groups )
2205
- groups_dict = {key : [] for key in unique_groups }
2206
- indices = np .arange (n_samples )
2207
- for i in indices :
2208
- temp = groups_dict .get (groups [i ])
2209
- temp .append (i )
2210
- groups_dict [groups [i ]] = temp
2211
- group_combinations = combinations (unique_groups , 2 )
2212
- for group_keys in list (group_combinations ):
2213
- index_list1 = groups_dict .get (group_keys [0 ])
2214
- index_list2 = groups_dict .get (group_keys [1 ])
2215
- sample_len = min (len (index_list1 ) ,len (index_list2 ))
2216
- test_size = (sample_len // n_folds )
2217
- test_starts = range (test_size + n_samples % n_folds ,
2218
- n_samples , test_size )
2219
- for test_start in test_starts :
2220
- if self .max_train_size and self .max_train_size < test_start :
2221
- yield (indices [test_start - self .max_train_size :test_start ],
2222
- indices [test_start :test_start + test_size ])
2223
- else :
2224
- yield (indices [:test_start ],
2225
- indices [test_start :test_start + test_size ])
2191
+ def split (self , X , y = None , groups = None ):
2192
+ X , y , groups = indexable (X , y , groups )
2193
+ n_samples = _num_samples (X )
2194
+ n_splits = self .n_splits
2195
+ n_folds = n_splits + 1
2196
+ indices = np .arange (n_samples )
2197
+ if n_folds > n_samples :
2198
+ raise ValueError (
2199
+ ("Cannot have number of folds ={0} greater"
2200
+ " than the number of samples: {1}." ).format (n_folds , n_samples ))
2201
+ else :
2202
+ test_size = (n_samples // n_folds )
2203
+ test_starts = range (test_size + n_samples % n_folds , n_samples ,
2204
+ test_size )
2205
+ for test_start in test_starts :
2206
+ if self .max_train_size and self .max_train_size < test_start :
2207
+ train_array = indices [test_start -
2208
+ self .max_train_size :test_start ]
2209
+ test_array = indices [test_start :test_start + test_size ]
2210
+ train_group = [groups [i ] for i in train_array ]
2211
+ for i in test_array :
2212
+ if groups [i ] in train_group :
2213
+ test_array = test_array [test_array != i ]
2214
+ if test_array .size <= 1 :
2215
+ continue
2216
+ yield (train_array , test_array )
2217
+ else :
2218
+ train_array = indices [:test_start ]
2219
+ test_array = indices [test_start :test_start + test_size ]
2220
+ train_group = [groups [i ] for i in train_array ]
2221
+ for i in test_array :
2222
+ if groups [i ] in train_group :
2223
+ test_array = test_array [test_array != i ]
2224
+ if test_array .size <= 1 :
2225
+ continue
2226
+ yield (train_array , test_array )
0 commit comments