|
30 | 30 | from sklearn.model_selection import KFold
|
31 | 31 | from sklearn.model_selection import StratifiedKFold
|
32 | 32 | from sklearn.model_selection import LabelKFold
|
| 33 | +from sklearn.model_selection import HomogeneousTimeSeriesCV |
33 | 34 | from sklearn.model_selection import LeaveOneOut
|
34 | 35 | from sklearn.model_selection import LeaveOneLabelOut
|
35 | 36 | from sklearn.model_selection import LeavePOut
|
@@ -970,6 +971,39 @@ def test_label_kfold():
|
970 | 971 | next, LabelKFold(n_folds=3).split(X, y, labels))
|
971 | 972 |
|
972 | 973 |
|
def test_homogeneous_time_series_cv():
    """Exercise HomogeneousTimeSeriesCV on a small, ordered toy dataset.

    Covers three things: the error raised when more folds than samples are
    requested, the exact train/test indices produced for an even-length and
    an odd-length input, and the value reported by ``get_n_splits``.
    """
    samples = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]

    # Asking for more folds than there are samples must be rejected; the
    # error is expected when the first split is pulled from the iterator.
    assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
                         next,
                         HomogeneousTimeSeriesCV(n_folds=9).split(samples))

    homo_tscv = HomogeneousTimeSeriesCV(3)

    # Six samples, three folds: every expected test window comes right after
    # its training window, so the original ordering of the data is kept.
    fold_iter = homo_tscv.split(samples[:-1])
    expected_even = (([0, 1], [2, 3]),
                     ([0, 1, 2, 3], [4, 5]))
    for expected_train, expected_test in expected_even:
        train_idx, test_idx = next(fold_iter)
        assert_array_equal(train_idx, expected_train)
        assert_array_equal(test_idx, expected_test)

    # Seven samples, three folds: the expected first training window holds
    # three indices while the test windows still hold two each.
    fold_iter = HomogeneousTimeSeriesCV(3).split(samples)
    expected_odd = (([0, 1, 2], [3, 4]),
                    ([0, 1, 2, 3, 4], [5, 6]))
    for expected_train, expected_test in expected_odd:
        train_idx, test_idx = next(fold_iter)
        assert_array_equal(train_idx, expected_train)
        assert_array_equal(test_idx, expected_test)

    # get_n_splits is expected to report the number of folds minus one.
    assert_equal(2, homo_tscv.get_n_splits())
| 1006 | + |
973 | 1007 | def test_nested_cv():
|
974 | 1008 | # Test if nested cross validation works with different combinations of cv
|
975 | 1009 | rng = np.random.RandomState(0)
|
|
0 commit comments