Add test for homogeneous TSCV · scikit-learn/scikit-learn@1348e4a · GitHub
[go: up one dir, main page]

Skip to content

Commit 1348e4a

Browse files
committed
Add test for homogeneous TSCV
1 parent bfd4a5e commit 1348e4a

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

sklearn/model_selection/tests/test_split.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sklearn.model_selection import KFold
3131
from sklearn.model_selection import StratifiedKFold
3232
from sklearn.model_selection import LabelKFold
33+
from sklearn.model_selection import HomogeneousTimeSeriesCV
3334
from sklearn.model_selection import LeaveOneOut
3435
from sklearn.model_selection import LeaveOneLabelOut
3536
from sklearn.model_selection import LeavePOut
@@ -984,6 +985,39 @@ def test_label_kfold():
984985
next, LabelKFold(n_folds=3).split(X, y, labels))
985986

986987

988+
def test_homogeneous_time_series_cv():
    """Exercise HomogeneousTimeSeriesCV on a seven-sample toy dataset.

    Covers three things: requesting more folds than samples raises a
    ValueError, the first two splits on six and on seven samples give the
    expected train/test indices, and get_n_splits reports 2 for 3 folds.
    """
    # Seven rows: [1, 2], [3, 4], ..., [13, 14]
    X = [[i, i + 1] for i in range(1, 14, 2)]

    # Asking for more folds than there are samples must be rejected
    too_many_folds = HomogeneousTimeSeriesCV(n_folds=9).split(X)
    assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
                         next, too_many_folds)

    homo_tscv = HomogeneousTimeSeriesCV(3)

    def check_leading_splits(splits, expected):
        # Pull exactly one split per expected pair and compare the indices
        for want_train, want_test in expected:
            train, test = next(splits)
            assert_array_equal(train, want_train)
            assert_array_equal(test, want_test)

    # Manually check that Homogeneous Time Series CV preserves the data
    # ordering on toy datasets: first with six samples, then with seven
    check_leading_splits(homo_tscv.split(X[:-1]),
                         [([0, 1], [2, 3]),
                          ([0, 1, 2, 3], [4, 5])])
    check_leading_splits(HomogeneousTimeSeriesCV(3).split(X),
                         [([0, 1, 2], [3, 4]),
                          ([0, 1, 2, 3, 4], [5, 6])])

    # get_n_splits is expected to be the number of folds minus one
    assert_equal(2, homo_tscv.get_n_splits())
1020+
9871021
def test_nested_cv():
9881022
# Test if nested cross validation works with different combinations of cv
9891023
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0