RollingWindow cross-validation · scikit-learn/scikit-learn@698d723

Commit 698d723
Author: x0l

RollingWindow cross-validation

A cross-validation strategy for timeseries; see http://robjhyndman.com/hyndsight/tscvexample. Initial commit, tests and unfinished docs.

1 parent 45e82c3   commit 698d723
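
For orientation before the diff: following Hyndman's post, the strategy trains on an expanding (or fixed-length, sliding) window of past observations and tests on the block that immediately follows, so no future information leaks into training. A minimal standalone sketch of the default expanding-window behaviour (editorial, not part of the commit; the helper name rolling_window_indices is hypothetical):

    import numpy as np

    def rolling_window_indices(n, test_size=1):
        """Yield (train, test) index arrays for a series of length n:
        train on everything before the test block, test on that block."""
        for start_test in range(1, n, test_size):
            end_test = min(start_test + test_size, n)
            yield np.arange(0, start_test), np.arange(start_test, end_test)

    for train, test in rolling_window_indices(5):
        print(train, test)
    # [0] [1]
    # [0 1] [2]
    # [0 1 2] [3]
    # [0 1 2 3] [4]

The RollingWindow class added below generalises this with test_size, train_size, delay and step parameters.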

File tree

4 files changed: +157 -12 lines changed

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
@@ -166,6 +166,7 @@ Classes
    cross_validation.LeaveOneOut
    cross_validation.LeavePLabelOut
    cross_validation.LeavePOut
+   cross_validation.RollingWindow
    cross_validation.StratifiedKFold
    cross_validation.ShuffleSplit
    cross_validation.StratifiedShuffleSplit

doc/modules/cross_validation.rst

Lines changed: 26 additions & 11 deletions
@@ -232,7 +232,7 @@ training set::
     [0 1 2] [3]


-Potential users of LOO for model selection should weigh a few known caveats.
+Potential users of LOO for model selection should weigh a few known caveats.
 When compared with :math:`k`-fold cross validation, one builds :math:`n` models
 from :math:`n` samples instead of :math:`k` models, where :math:`n > k`.
 Moreover, each is trained on :math:`n - 1` samples rather than
@@ -246,10 +246,10 @@ the :math:`n` samples are used to build each model, models constructed from
 folds are virtually identical to each other and to the model built from the
 entire training set.

-However, if the learning curve is steep for the training size in question,
+However, if the learning curve is steep for the training size in question,
 then 5- or 10- fold cross validation can overestimate the generalization error.

-As a general rule, most authors, and empirical evidence, suggest that 5- or 10-
+As a general rule, most authors, and empirical evidence, suggest that 5- or 10-
 fold cross validation should be preferred to LOO.


@@ -261,7 +261,7 @@ fold cross validation should be preferred to LOO.
 * L. Breiman, P. Spector `Submodel selection and evaluation in regression: The X-random case
   <http://digitalassets.lib.berkeley.edu/sdtr/ucb/text/197.pdf>`_, International Statistical Review 1992
 * R. Kohavi, `A Study of Cross-Validation and Bootstrap for Accuracy Estimation and Model Selection
-  <http://www.cs.iastate.edu/~jtian/cs573/Papers/Kohavi-IJCAI-95.pdf>`_, Intl. Jnt. Conf. AI
+  <http://www.cs.iastate.edu/~jtian/cs573/Papers/Kohavi-IJCAI-95.pdf>`_, Intl. Jnt. Conf. AI
 * R. Bharat Rao, G. Fung, R. Rosales, `On the Dangers of Cross-Validation. An Experimental Evaluation
   <http://www.siam.org/proceedings/datamining/2008/dm08_54_Rao.pdf>`_, SIAM 2008
 * G. James, D. Witten, T. Hastie, R Tibshirani, `An Introduction to
@@ -354,8 +354,6 @@ Example of Leave-2-Label Out::
 Random permutations cross-validation a.k.a. Shuffle & Split
 -----------------------------------------------------------

-:class:`ShuffleSplit`
-
 The :class:`ShuffleSplit` iterator will generate a user defined number of
 independent train / test dataset splits. Samples are first shuffled and
 then split into a pair of train and test sets.
@@ -379,11 +377,28 @@ Here is a usage example::
 validation that allows a finer control on the number of iterations and
 the proportion of samples on each side of the train / test split.

-See also
---------
-:class:`StratifiedShuffleSplit` is a variation of *ShuffleSplit*, which returns
-stratified splits, *i.e* which creates splits by preserving the same
-percentage for each target class as in the complete set.
+.. note::
+
+   See also :class:`StratifiedShuffleSplit`: this is a variation of
+   *ShuffleSplit*, which returns stratified splits, *i.e.* which creates
+   splits by preserving the same percentage for each target class as in
+   the complete set.
+
+Rolling window
+--------------
+
+:class:`RollingWindow` is a cross-validation strategy suited to
+timeseries: each training set contains only samples that come before
+the corresponding test set, so no future information leaks into
+training.
+
+Here is a usage example::
+
+    >>> rw = cross_validation.RollingWindow(5)
+    >>> for train_index, test_index in rw:
+    ...     print("%s %s" % (train_index, test_index))
+    ...
+    [0] [1]
+    [0 1] [2]
+    [0 1 2] [3]
+    [0 1 2 3] [4]

 A note on shuffling
 ===================
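
The documented example only exercises the defaults. As a sketch of how the new train_size and delay parameters interact (editorial, not part of the commit; the output below is derived by hand from the iteration logic in the patch that follows):

    >>> rw = cross_validation.RollingWindow(8, test_size=2, train_size=3, delay=1)
    >>> for train_index, test_index in rw:
    ...     print("%s %s" % (train_index, test_index))
    ...
    [0 1 2] [4 5]
    [2 3 4] [6 7]

With a fixed train_size the training window slides instead of growing, and delay=1 leaves one sample unused between each training set and its test set (index 3 in the first split, index 5 in the second).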

sklearn/cross_validation.py

Lines changed: 103 additions & 1 deletion
@@ -27,7 +27,7 @@
 from .utils.multiclass import type_of_target
 from .externals.joblib import Parallel, delayed, logger
 from .externals.six import with_metaclass
-from .externals.six.moves import zip
+from .externals.six.moves import zip, xrange
 from .metrics.scorer import check_scoring

 __all__ = ['Bootstrap',
@@ -36,6 +36,7 @@
            'LeaveOneOut',
            'LeavePLabelOut',
            'LeavePOut',
+           'RollingWindow',
            'ShuffleSplit',
            'StratifiedKFold',
            'StratifiedShuffleSplit',
@@ -1073,6 +1074,107 @@ def __len__(self):
         return self.n_iter


+class RollingWindow(object):
+    """Rolling window cross-validation strategy for timeseries
+
+    Provides train/test indices increasing with time.
+
+    Parameters
+    ----------
+    n : int
+        Total number of elements in the dataset.
+
+    test_size : float, int (default is 1)
+        If float, should be between 0.0 and 1.0 and represent the
+        proportion of the dataset to include in the test split. If
+        int, represents the absolute number of test samples.
+
+    train_size : float, int, or None (default is None)
+        If float, should be between 0.0 and 1.0 and represent the
+        proportion of the dataset to include in the train split. If
+        int, represents the absolute number of train samples. If
+        None, the train size grows progressively.
+
+    delay : int (default is 0)
+        Number of samples skipped between the train and the test sets.
+
+    step : float, int, or None (default is None)
+        Increment between successive test windows. If None, step is
+        set equal to test_size.
+
+    Examples
+    --------
+    >>> from sklearn import cross_validation
+    >>> rw = cross_validation.RollingWindow(5)
+    >>> len(rw)
+    4
+    >>> print(rw)
+    RollingWindow(5, test_size=1, train_size=None, delay=0, step=None)
+    >>> for train_index, test_index in rw:
+    ...     print("TRAIN:", train_index, "TEST:", test_index)
+    ...
+    TRAIN: [0] TEST: [1]
+    TRAIN: [0 1] TEST: [2]
+    TRAIN: [0 1 2] TEST: [3]
+    TRAIN: [0 1 2 3] TEST: [4]
+
+    References
+    ----------
+    See http://robjhyndman.com/hyndsight/tscvexample
+    """
+
+    def __init__(self, n, test_size=1, train_size=None, delay=0, step=None):
+        self.n = n
+        self.test_size = test_size
+        self.train_size = train_size
+        self.delay = delay
+        self.step = step
+
+        self.n_train, self.n_test = _validate_shuffle_split(n,
+                                                            test_size,
+                                                            train_size)
+        # Default step is the test window length; a float step is a
+        # fraction of n, an int step an absolute number of samples.
+        self.step_ = self.n_test
+        if np.asarray(step).dtype.kind == 'f':
+            self.step_ = int(ceil(step * n))
+        elif np.asarray(step).dtype.kind == 'i':
+            self.step_ = step
+
+    def __iter__(self):
+        start_train = 0
+
+        # The first test window starts after one sample (expanding
+        # window) or after a full train window, plus the delay.
+        first_idx = 1 if self.train_size is None else self.n_train
+        first_idx += self.delay
+
+        for start_test in xrange(first_idx, self.n, self.step_):
+
+            end_test = min(start_test + self.n_test, self.n)
+
+            if self.train_size is not None:
+                # Fixed-size train window sliding along with the test.
+                start_train = start_test - self.delay - self.n_train
+
+            train = np.arange(start_train, start_test - self.delay)
+            test = np.arange(start_test, end_test)
+
+            yield train, test
+
+    def __len__(self):
+        first_idx = self.n_train if self.train_size is not None else 1
+        first_idx += self.delay
+
+        l = (self.n - first_idx) / self.step_
+        return int(ceil(l))
+
+    def __repr__(self):
+        return ('%s(%i, test_size=%i, train_size=%s, delay=%i, '
+                'step=%s)' % (
+                    self.__class__.__name__,
+                    self.n,
+                    self.test_size,
+                    str(self.train_size),
+                    self.delay,
+                    str(self.step),
+                ))
+
+
 ##############################################################################
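
A further hypothetical snippet (editorial, not in the patch) exercising the float form of step, which __init__ above resolves to int(ceil(step * n)) samples; here ceil(0.2 * 10) = 2, so a new split starts at every second sample while the training window keeps expanding:

    >>> rw = cross_validation.RollingWindow(10, step=0.2)
    >>> for train_index, test_index in rw:
    ...     print("%s %s" % (train_index, test_index))
    ...
    [0] [1]
    [0 1 2] [3]
    [0 1 2 3 4] [5]
    [0 1 2 3 4 5 6] [7]
    [0 1 2 3 4 5 6 7 8] [9]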

sklearn/tests/test_cross_validation.py

Lines changed: 27 additions & 0 deletions
@@ -461,6 +461,33 @@ def test_leave_label_out_changing_labels():
     assert_array_equal(test, test_chan)


+def test_rolling_window_split():
+    rw1 = cval.RollingWindow(10, test_size=0.2)
+    rw2 = cval.RollingWindow(10, test_size=2)
+    rw3 = cval.RollingWindow(10, test_size=np.int32(2))
+    for typ in six.integer_types:
+        rw4 = cval.RollingWindow(10, test_size=typ(2))
+        for t1, t2, t3, t4 in zip(rw1, rw2, rw3, rw4):
+            assert_array_equal(t1[0], t2[0])
+            assert_array_equal(t2[0], t3[0])
+            assert_array_equal(t3[0], t4[0])
+            assert_array_equal(t1[1], t2[1])
+            assert_array_equal(t2[1], t3[1])
+            assert_array_equal(t3[1], t4[1])
+    rw5 = cval.RollingWindow(5, train_size=2)
+    assert_equal(len(rw5), 3)
+    for t in rw5:
+        assert_equal(len(t[0]), 2)
+        assert_equal(len(t[1]), 1)
+    rw6 = cval.RollingWindow(10, step=0.2)
+    rw7 = cval.RollingWindow(10, step=2)
+    for t1, t6, t7 in zip(rw1, rw6, rw7):
+        assert_array_equal(t1[0], t6[0])
+        assert_array_equal(t1[0], t7[0])
+        assert_equal(t1[1][0], t6[1][0])
+        assert_equal(t1[1][0], t7[1][0])
+
+
 def test_cross_val_score():
     clf = MockClassifier()
     for a in range(-10, 10):
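
Beyond these unit tests, the natural end-to-end use is passing a RollingWindow instance as the cv argument of cross_val_score. A rough sketch (editorial, not part of the commit, assuming cross_val_score accepts the object like any other iterable of (train, test) index pairs; the toy data is invented):

    import numpy as np
    from sklearn import cross_validation, linear_model

    X = np.arange(20.).reshape(10, 2)  # toy "timeseries" features
    y = np.arange(10.)                 # toy target

    # Three sliding splits: train on the 4 preceding samples, test on
    # the next 2 ([0..3]->[4 5], [2..5]->[6 7], [4..7]->[8 9]).
    rw = cross_validation.RollingWindow(10, train_size=4, test_size=2)
    scores = cross_validation.cross_val_score(
        linear_model.LinearRegression(), X, y, cv=rw)
    print(scores)  # one score per rolling split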
