[MRG] Add homogeneous time series cross validation (#6586) · scikit-learn/scikit-learn@234d256 · GitHub

Commit 234d256

yenchenlin authored and jnothman committed
[MRG] Add homogeneous time series cross validation (#6586)
1 parent 040a766 commit 234d256

File tree

4 files changed

+177 -0 lines changed

doc/modules/cross_validation.rst

Lines changed: 44 additions & 0 deletions
@@ -521,6 +521,50 @@ See also
 stratified splits, *i.e.* which creates splits by preserving the same
 percentage for each target class as in the complete set.

+Cross validation of time series data
+====================================
+
+Time series data is characterised by the correlation between observations
+that are near in time (*autocorrelation*). However, classical
+cross-validation techniques such as :class:`KFold` and
+:class:`ShuffleSplit` assume the samples are independent and
+identically distributed, and would result in unreasonable correlation
+between training and testing instances (yielding poor estimates of
+generalisation error) on time series data. It is therefore very important
+to evaluate a model for time series data on the "future" observations,
+which are least like those used to train the model. To achieve this, one
+solution is provided by :class:`TimeSeriesCV`.
+
+
+TimeSeriesCV
+------------
+
+:class:`TimeSeriesCV` is a variation of *k-fold* which, in the
+:math:`k` th split, returns the first :math:`k` folds as the train set
+and the :math:`(k+1)` th fold as the test set. Note that unlike standard
+cross-validation methods, successive training sets are supersets of those
+that come before them. It also adds all surplus data to the first training
+partition, which is always used to train the model.
+
+This class can be used to cross-validate time series data samples
+that are observed at fixed time intervals.
+
+Example of 3-split time series cross-validation on a dataset with 6 samples::
+
+  >>> from sklearn.model_selection import TimeSeriesCV
+
+  >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
+  >>> y = np.array([1, 2, 3, 4, 5, 6])
+  >>> tscv = TimeSeriesCV(n_splits=3)
+  >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
+  TimeSeriesCV(n_splits=3)
+  >>> for train, test in tscv.split(X):
+  ...     print("%s %s" % (train, test))
+  [0 1 2] [3]
+  [0 1 2 3] [4]
+  [0 1 2 3 4] [5]
+
+
 A note on shuffling
 ===================

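As a rough usage sketch (not part of this commit), the splitter documented above can be passed as ``cv`` to ``cross_val_score`` so that every score is computed on a block of "future" samples that comes strictly after the data used to fit the model; the ``Ridge`` estimator and the toy noisy-sine series below are illustrative assumptions only:

    import numpy as np

    from sklearn.linear_model import Ridge
    from sklearn.model_selection import TimeSeriesCV, cross_val_score

    # Toy autocorrelated series: a noisy sine wave indexed by time step.
    rng = np.random.RandomState(0)
    X = np.arange(50, dtype=float).reshape(-1, 1)
    y = np.sin(X.ravel() / 5.0) + 0.1 * rng.randn(50)

    # Each of the 5 scores is computed on samples observed after all of
    # the samples used to train the model for that split.
    scores = cross_val_score(Ridge(), X, y, cv=TimeSeriesCV(n_splits=5))
    print(scores)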

sklearn/model_selection/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 from ._split import KFold
 from ._split import LabelKFold
 from ._split import StratifiedKFold
+from ._split import TimeSeriesCV
 from ._split import LeaveOneLabelOut
 from ._split import LeaveOneOut
 from ._split import LeavePLabelOut
@@ -27,6 +28,7 @@

 __all__ = ('BaseCrossValidator',
            'GridSearchCV',
+           'TimeSeriesCV',
            'KFold',
            'LabelKFold',
            'LabelShuffleSplit',

sklearn/model_selection/_split.py

Lines changed: 92 additions & 0 deletions
@@ -635,6 +635,98 @@ def split(self, X, y, labels=None):
         return super(StratifiedKFold, self).split(X, y, labels)


+class TimeSeriesCV(_BaseKFold):
+    """Time Series cross-validator
+
+    Provides train/test indices to split time series data samples
+    that are observed at fixed time intervals, in train/test sets.
+    In each split, test indices must be higher than before, and thus shuffling
+    in cross validator is inappropriate.
+
+    This cross-validation object is a variation of :class:`KFold`.
+    In the kth split, it returns first k folds as train set and the
+    (k+1)th fold as test set.
+
+    Note that unlike standard cross-validation methods, successive
+    training sets are supersets of those that come before them.
+
+    Read more in the :ref:`User Guide <cross_validation>`.
+
+    Parameters
+    ----------
+    n_splits : int, default=3
+        Number of splits. Must be at least 1.
+
+    Examples
+    --------
+    >>> from sklearn.model_selection import TimeSeriesCV
+    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
+    >>> y = np.array([1, 2, 3, 4])
+    >>> tscv = TimeSeriesCV(n_splits=3)
+    >>> print(tscv)  # doctest: +NORMALIZE_WHITESPACE
+    TimeSeriesCV(n_splits=3)
+    >>> for train_index, test_index in tscv.split(X):
+    ...    print("TRAIN:", train_index, "TEST:", test_index)
+    ...    X_train, X_test = X[train_index], X[test_index]
+    ...    y_train, y_test = y[train_index], y[test_index]
+    TRAIN: [0] TEST: [1]
+    TRAIN: [0 1] TEST: [2]
+    TRAIN: [0 1 2] TEST: [3]
+
+    Notes
+    -----
+    The training set has size ``i * n_samples // (n_splits + 1)
+    + n_samples % (n_splits + 1)`` in the ``i``th split,
+    with a test set of size ``n_samples // (n_splits + 1)``,
+    where ``n_samples`` is the number of samples.
+    """
+    def __init__(self, n_splits=3):
+        super(TimeSeriesCV, self).__init__(n_splits,
+                                           shuffle=False,
+                                           random_state=None)
+
+    def split(self, X, y=None, labels=None):
+        """Generate indices to split data into training and test set.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training data, where n_samples is the number of samples
+            and n_features is the number of features.
+
+        y : array-like, shape (n_samples,)
+            The target variable for supervised learning problems.
+
+        labels : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set.
+
+        Returns
+        -------
+        train : ndarray
+            The training set indices for that split.
+
+        test : ndarray
+            The testing set indices for that split.
+        """
+        X, y, labels = indexable(X, y, labels)
+        n_samples = _num_samples(X)
+        n_splits = self.n_splits
+        n_folds = n_splits + 1
+        if n_folds > n_samples:
+            raise ValueError(
+                ("Cannot have number of folds ={0} greater"
+                 " than the number of samples: {1}.").format(n_folds,
+                                                             n_samples))
+        indices = np.arange(n_samples)
+        test_size = (n_samples // n_folds)
+        test_starts = range(test_size + n_samples % n_folds,
+                            n_samples, test_size)
+        for test_start in test_starts:
+            yield (indices[:test_start],
+                   indices[test_start:test_start + test_size])
+
+
 class LeaveOneLabelOut(BaseCrossValidator):
     """Leave One Label Out cross-validator
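To make the split sizes in the docstring's ``Notes`` section concrete, here is a small hand-check (not part of the commit) for ``n_samples = 7`` and ``n_splits = 2``: ``n_folds = n_splits + 1 = 3``, the test size is ``7 // 3 = 2``, and the surplus ``7 % 3 = 1`` sample is added to the first training set:

    import numpy as np

    from sklearn.model_selection import TimeSeriesCV

    X = np.zeros((7, 1))  # only the number of samples matters for the indices
    for train, test in TimeSeriesCV(n_splits=2).split(X):
        # Expected output: "[0 1 2] [3 4]" then "[0 1 2 3 4] [5 6]",
        # matching the assertions in test_time_series_cv below.
        print(train, test)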

sklearn/model_selection/tests/test_split.py

Lines changed: 39 additions & 0 deletions
@@ -30,6 +30,7 @@
 from sklearn.model_selection import KFold
 from sklearn.model_selection import StratifiedKFold
 from sklearn.model_selection import LabelKFold
+from sklearn.model_selection import TimeSeriesCV
 from sklearn.model_selection import LeaveOneOut
 from sklearn.model_selection import LeaveOneLabelOut
 from sklearn.model_selection import LeavePOut
@@ -997,6 +998,44 @@ def test_label_kfold():
                          next, LabelKFold(n_splits=3).split(X, y, labels))


+def test_time_series_cv():
+    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
+
+    # Should fail if there are more folds than samples
+    assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
+                         next,
+                         TimeSeriesCV(n_splits=7).split(X))
+
+    tscv = TimeSeriesCV(2)
+
+    # Manually check that Time Series CV preserves the data
+    # ordering on toy datasets
+    splits = tscv.split(X[:-1])
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1])
+    assert_array_equal(test, [2, 3])
+
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1, 2, 3])
+    assert_array_equal(test, [4, 5])
+
+    splits = TimeSeriesCV(2).split(X)
+
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1, 2])
+    assert_array_equal(test, [3, 4])
+
+    train, test = next(splits)
+    assert_array_equal(train, [0, 1, 2, 3, 4])
+    assert_array_equal(test, [5, 6])
+
+    # Check get_n_splits returns the correct number of splits
+    splits = TimeSeriesCV(2).split(X)
+    n_splits_actual = len(list(splits))
+    assert_equal(n_splits_actual, tscv.get_n_splits())
+    assert_equal(n_splits_actual, 2)
+
+
 def test_nested_cv():
     # Test if nested cross validation works with different combinations of cv
     rng = np.random.RandomState(0)
