8000 Rename `TimeSeriesCV` to `TimeSeriesSplit` (#7245) · TomDLT/scikit-learn@9f457f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9f457f5

Browse files
yenchenlinTomDLT
authored andcommitted
Rename TimeSeriesCV to TimeSeriesSplit (scikit-learn#7245)
* rename TimeSeriesCV to TimeSeriesSplit * Add TimeSeriesSplit * Add whats new
1 parent f1198c0 commit 9f457f5

File tree

6 files changed

+27
-20
lines changed

6 files changed

+27
-20
lines changed

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ Splitter Classes
174174
model_selection.LabelShuffleSplit
175175
model_selection.StratifiedShuffleSplit
176176
model_selection.PredefinedSplit
177+
model_selection.TimeSeriesSplit
177178

178179
Splitter Functions
179180
------------------

doc/modules/cross_validation.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -533,13 +533,13 @@ between training and testing instances (yielding poor estimates of
533533
generalisation error) on time series data. Therefore, it is very important
534534
to evaluate our model for time series data on the "future" observations
535535
least like those that are used to train the model. To achieve this, one
536-
solution is provided by :class:`TimeSeriesCV`.
536+
solution is provided by :class:`TimeSeriesSplit`.
537537

538538

539-
TimeSeriesCV
539+
TimeSeriesSplit
540540
-----------------------
541541

542-
:class:`TimeSeriesCV` is a variation of *k-fold* which
542+
:class:`TimeSeriesSplit` is a variation of *k-fold* which
543543
returns first :math:`k` folds as train set and the :math:`(k+1)` th
544544
fold as test set. Note that unlike standard cross-validation methods,
545545
successive training sets are supersets of those that come before them.
@@ -551,13 +551,13 @@ that are observed at fixed time intervals.
551551

552552
Example of 3-split time series cross-validation on a dataset with 6 samples::
553553

554-
>>> from sklearn.model_selection import TimeSeriesCV
554+
>>> from sklearn.model_selection import TimeSeriesSplit
555555

556556
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
557557
>>> y = np.array([1, 2, 3, 4, 5, 6])
558-
>>> tscv = TimeSeriesCV(n_splits=3)
558+
>>> tscv = TimeSeriesSplit(n_splits=3)
559559
>>> print(tscv) # doctest: +NORMALIZE_WHITESPACE
560-
TimeSeriesCV(n_splits=3)
560+
TimeSeriesSplit(n_splits=3)
561561
>>> for train, test in tscv.split(X):
562562
... print("%s %s" % (train, test))
563563
[0 1 2] [3]

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ New features
141141
<https://github.com/scikit-learn/scikit-learn/pull/6954>`_) by `Nelson
142142
Liu`_
143143

144+
- Added new cross-validation splitter
145+
:class:`model_selection.TimeSeriesSplit` to handle time series data.
146+
(`#6586
147+
<https://github.com/scikit-learn/scikit-learn/pull/6586>`_) by `YenChen
148+
Lin`_
149+
144150
Enhancements
145151
............
146152

sklearn/model_selection/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ._split import KFold
33
from ._split import LabelKFold
44
from ._split import StratifiedKFold
5-
from ._split import TimeSeriesCV
5+
from ._split import TimeSeriesSplit
66
from ._split import LeaveOneLabelOut
77
from ._split import LeaveOneOut
88
from ._split import LeavePLabelOut
@@ -28,7 +28,7 @@
2828

2929
__all__ = ('BaseCrossValidator',
3030
'GridSearchCV',
31-
'TimeSeriesCV',
31+
'TimeSeriesSplit',
3232
'KFold',
3333
'LabelKFold',
3434
'LabelShuffleSplit',

sklearn/model_selection/_split.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def split(self, X, y, labels=None):
635635
return super(StratifiedKFold, self).split(X, y, labels)
636636

637637

638-
class TimeSeriesCV(_BaseKFold):
638+
class TimeSeriesSplit(_BaseKFold):
639639
"""Time Series cross-validator
640640
641641
Provides train/test indices to split time series data samples
@@ -659,12 +659,12 @@ class TimeSeriesCV(_BaseKFold):
659659
660660
Examples
661661
--------
662-
>>> from sklearn.model_selection import TimeSeriesCV
662+
>>> from sklearn.model_selection import TimeSeriesSplit
663663
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
664664
>>> y = np.array([1, 2, 3, 4])
665-
>>> tscv = TimeSeriesCV(n_splits=3)
665+
>>> tscv = TimeSeriesSplit(n_splits=3)
666666
>>> print(tscv) # doctest: +NORMALIZE_WHITESPACE
667-
TimeSeriesCV(n_splits=3)
667+
TimeSeriesSplit(n_splits=3)
668668
>>> for train_index, test_index in tscv.split(X):
669669
... print("TRAIN:", train_index, "TEST:", test_index)
670670
... X_train, X_test = X[train_index], X[test_index]
@@ -681,9 +681,9 @@ class TimeSeriesCV(_BaseKFold):
681681
where ``n_samples`` is the number of samples.
682682
"""
683683
def __init__(self, n_splits=3):
684-
super(TimeSeriesCV, self).__init__(n_splits,
685-
shuffle=False,
686-
random_state=None)
684+
super(TimeSeriesSplit, self).__init__(n_splits,
685+
shuffle=False,
686+
random_state=None)
687687

688688
def split(self, X, y=None, labels=None):
689689
"""Generate indices to split data into training and test set.

sklearn/model_selection/tests/test_split.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from sklearn.model_selection import KFold
3131
from sklearn.model_selection import StratifiedKFold
3232
from sklearn.model_selection import LabelKFold
33-
from sklearn.model_selection import TimeSeriesCV
33+
from sklearn.model_selection import TimeSeriesSplit
3434
from sklearn.model_selection import LeaveOneOut
3535
from sklearn.model_selection import LeaveOneLabelOut
3636
from sklearn.model_selection import LeavePOut
@@ -1004,9 +1004,9 @@ def test_time_series_cv():
10041004
# Should fail if there are more folds than samples
10051005
assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
10061006
next,
1007-
TimeSeriesCV(n_splits=7).split(X))
1007+
TimeSeriesSplit(n_splits=7).split(X))
10081008

1009-
tscv = TimeSeriesCV(2)
1009+
tscv = TimeSeriesSplit(2)
10101010

10111011
# Manually check that Time Series CV preserves the data
10121012
# ordering on toy datasets
@@ -1019,7 +1019,7 @@ def test_time_series_cv():
10191019
assert_array_equal(train, [0, 1, 2, 3])
10201020
assert_array_equal(test, [4, 5])
10211021

1022-
splits = TimeSeriesCV(2).split(X)
1022+
splits = TimeSeriesSplit(2).split(X)
10231023

10241024
train, test = next(splits)
10251025
assert_array_equal(train, [0, 1, 2])
@@ -1030,7 +1030,7 @@ def test_time_series_cv():
10301030
assert_array_equal(test, [5, 6])
10311031

10321032
# Check get_n_splits returns the correct number of splits
1033-
splits = TimeSeriesCV(2).split(X)
1033+
splits = TimeSeriesSplit(2).split(X)
10341034
n_splits_actual = len(list(splits))
10351035
assert_equal(n_splits_actual, tscv.get_n_splits())
10361036
assert_equal(n_splits_actual, 2)

0 commit comments

Comments
 (0)
0