8000 doc(cv): add TimeSeriesCV group split example · scikit-learn/scikit-learn@2a30d5a · GitHub
[go: up one dir, main page]

Skip to content

Commit 2a30d5a

Browse files
8000 tczhaotczhao
committed
doc(cv): add TimeSeriesCV group split example
1 parent 9ce813d commit 2a30d5a

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

sklearn/model_selection/_split.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,15 @@ class TimeSeriesSplit(_BaseKFold):
815815
TRAIN: [0 1 2 3] TEST: [6 7]
816816
TRAIN: [0 1 2 3 4 5] TEST: [8 9]
817817
TRAIN: [0 1 2 3 4 5 6 7] TEST: [10 11]
818+
>>> # Split using group
819+
>>> tscv = TimeSeriesSplit(n_splits=2)
820+
>>> groups = np.array([1, 1, 2, 3, 4, 4])
821+
>>> for train_index, test_index in tscv.split(X, groups=groups):
822+
... print("TRAIN:", train_index, "TEST:", test_index)
823+
... X_train, X_test = X[train_index], X[test_index]
824+
... y_train, y_test = y[train_index], y[test_index]
825+
TRAIN: [0 1 2] TEST: [3]
826+
TRAIN: [0 1 2 3] TEST: [4 5]
818827
819828
Notes
820829
-----

0 commit comments

Comments
 (0)
0