|
30 | 30 | from sklearn.model_selection import KFold
|
31 | 31 | from sklearn.model_selection import StratifiedKFold
|
32 | 32 | from sklearn.model_selection import LabelKFold
|
| 33 | +from sklearn.model_selection import HomogeneousTimeSeriesCV |
33 | 34 | from sklearn.model_selection import LeaveOneOut
|
34 | 35 | from sklearn.model_selection import LeaveOneLabelOut
|
35 | 36 | from sklearn.model_selection import LeavePOut
|
@@ -984,6 +985,39 @@ def test_label_kfold():
|
984 | 985 | next, LabelKFold(n_folds=3).split(X, y, labels))
|
985 | 986 |
|
986 | 987 |
|
| 988 | +def test_homogeneous_time_series_cv(): |
| 989 | + X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]] |
| 990 | + |
| 991 | + # Should fail if there are more folds than samples |
| 992 | + assert_raises_regexp(ValueError, "Cannot have number of folds.*greater", |
| 993 | + next, |
| 994 | + HomogeneousTimeSeriesCV(n_folds=9).split(X)) |
| 995 | + |
| 996 | + homo_tscv = HomogeneousTimeSeriesCV(3) |
| 997 | + |
| 998 | + # Manually check that Homogeneous Time Series CV preserves the data |
| 999 | + # ordering on toy datasets |
| 1000 | + splits = homo_tscv.split(X[:-1]) |
| 1001 | + train, test = next(splits) |
| 1002 | + assert_array_equal(train, [0, 1]) |
| 1003 | + assert_array_equal(test, [2, 3]) |
| 1004 | + |
| 1005 | + train, test = next(splits) |
| 1006 | + assert_array_equal(train, [0, 1, 2, 3]) |
| 1007 | + assert_array_equal(test, [4, 5]) |
| 1008 | + |
| 1009 | + splits = HomogeneousTimeSeriesCV(3).split(X) |
| 1010 | + train, test = next(splits) |
| 1011 | + assert_array_equal(train, [0, 1, 2]) |
| 1012 | + assert_array_equal(test, [3, 4]) |
| 1013 | + |
| 1014 | + train, test = next(splits) |
| 1015 | + assert_array_equal(train, [0, 1, 2, 3, 4]) |
| 1016 | + assert_array_equal(test, [5, 6]) |
| 1017 | + |
| 1018 | + # Check get_n_splits returns the number of folds - 1 |
| 1019 | + assert_equal(2, homo_tscv.get_n_splits()) |
| 1020 | + |
987 | 1021 | def test_nested_cv():
|
988 | 1022 | # Test if nested cross validation works with different combinations of cv
|
989 | 1023 | rng = np.random.RandomState(0)
|
|
0 commit comments