8000 Add test for homogeneous TSCV · scikit-learn/scikit-learn@1348e4a · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

8000

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1348e4a

Browse files
committed
Add test for homogeneous TSCV
1 parent bfd4a5e commit 1348e4a

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

sklearn/model_selection/tests/test_split.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sklearn.model_selection import KFold
3131
from sklearn.model_selection import StratifiedKFold
3232
from sklearn.model_selection import LabelKFold
33+
from sklearn.model_selection import HomogeneousTimeSeriesCV
3334
from sklearn.model_selection import LeaveOneOut
3435
from sklearn.model_selection import LeaveOneLabelOut
3536
from sklearn.model_selection import LeavePOut
@@ -984,6 +985,39 @@ def test_label_kfold():
984985
next, LabelKFold(n_folds=3).split(X, y, labels))
985986

986987

988+
def test_homogeneous_time_series_cv():
989< 100B9 code class="diff-text syntax-highlighted-line addition">+
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
990+
991+
# Should fail if there are more folds than samples
992+
assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
993+
next,
994+
HomogeneousTimeSeriesCV(n_folds=9).split(X))
995+
996+
homo_tscv = HomogeneousTimeSeriesCV(3)
997+
998+
# Manually check that Homogeneous Time Series CV preserves the data
999+
# ordering on toy datasets
1000+
splits = homo_tscv.split(X[:-1])
1001+
train, test = next(splits)
1002+
assert_array_equal(train, [0, 1])
1003+
assert_array_equal(test, [2, 3])
1004+
1005+
train, test = next(splits)
1006+
assert_array_equal(train, [0, 1, 2, 3])
1007+
assert_array_equal(test, [4, 5])
1008+
1009+
splits = HomogeneousTimeSeriesCV(3).split(X)
1010+
train, test = next(splits)
1011+
assert_array_equal(train, [0, 1, 2])
1012+
assert_array_equal(test, [3, 4])
1013+
1014+
train, test = next(splits)
1015+
assert_array_equal(train, [0, 1, 2, 3, 4])
1016+
assert_array_equal(test, [5, 6])
1017+
1018+
# Check get_n_splits returns the number of folds - 1
1019+
assert_equal(2, homo_tscv.get_n_splits())
1020+
9871021
def test_nested_cv():
9881022
# Test if nested cross validation works with different combinations of cv
9891023
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0