FEA Additional `TimeSeriesSplit` Functionality (#13204) · jayzed82/scikit-learn@65c1f62 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit 65c1f62

Browse files
kykosic, Kyle Kosic, and thomasjpfan
authored and committed
FEA Additional TimeSeriesSplit Functionality (scikit-learn#13204)
Co-authored-by: Kyle Kosic <kylekosic@Kyles-MacBook-Pro.local> Co-authored-by: Thomas J Fan <thomasjpfan@gmail.com>
1 parent b4293cb commit 65c1f62

File tree

4 files changed

+161
-12
lines changed

4 files changed

+161
-12
lines changed

doc/modules/cross_validation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,7 @@ Example of 3-split time series cross-validation on a dataset with 6 samples::
782782
>>> y = np.array([1, 2, 3, 4, 5, 6])
783783
>>> tscv = TimeSeriesSplit(n_splits=3)
784784
>>> print(tscv)
785-
TimeSeriesSplit(max_train_size=None, n_splits=3)
785+
TimeSeriesSplit(gap=0, max_train_size=None, n_splits=3, test_size=None)
786786
>>> for train, test in tscv.split(X):
787787
... print("%s %s" % (train, test))
788788
[0 1 2] [3]

doc/whats_new/v0.24.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ Changelog
4747
:mod:`sklearn.module`
4848
.....................
4949

50+
:mod:`sklearn.model_selection`
51+
..............................
52+
53+
- |Enhancement| :class:`model_selection.TimeSeriesSplit` has two new keyword
54+
arguments `test_size` and `gap`. `test_size` allows the out-of-sample
55+
time series length to be fixed for all folds. `gap` removes a fixed number of
56+
samples between the train and test set on each fold.
57+
:pr:`13204` by :user:`Kyle Kosic <kykosic>`.
58+
5059

5160
Code and Documentation Contributors
5261
-----------------------------------

sklearn/model_selection/_split.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,15 @@ class TimeSeriesSplit(_BaseKFold):
766766
max_train_size : int, default=None
767767
Maximum size for a single training set.
768768
769+
test_size : int, default=None
770+
Used to limit the size of the test set. Defaults to
771+
``n_samples // (n_splits + 1)``, which is the maximum allowed value
772+
with ``gap=0``.
773+
774+
gap : int, default=0
775+
Number of samples to exclude from the end of each train set before
776+
the test set.
777+
769778
Examples
770779
--------
771780
>>> import numpy as np
@@ -774,7 +783,7 @@ class TimeSeriesSplit(_BaseKFold):
774783
>>> y = np.array([1, 2, 3, 4, 5, 6])
775784
>>> tscv = TimeSeriesSplit()
776785
>>> print(tscv)
777-
TimeSeriesSplit(max_train_size=None, n_splits=5)
786+
TimeSeriesSplit(gap=0, max_train_size=None, n_splits=5, test_size=None)
778787
>>> for train_index, test_index in tscv.split(X):
779788
... print("TRAIN:", train_index, "TEST:", test_index)
780789
... X_train, X_test = X[train_index], X[test_index]
@@ -784,18 +793,45 @@ class TimeSeriesSplit(_BaseKFold):
784793
TRAIN: [0 1 2] TEST: [3]
785794
TRAIN: [0 1 2 3] TEST: [4]
786795
TRAIN: [0 1 2 3 4] TEST: [5]
796+
>>> # Fix test_size to 2 with 12 samples
797+
>>> X = np.random.randn(12, 2)
798+
>>> y = np.random.randint(0, 2, 12)
799+
>>> tscv = TimeSeriesSplit(n_splits=3, test_size=2)
800+
>>> for train_index, test_index in tscv.split(X):
801+
... print("TRAIN:", train_index, "TEST:", test_index)
802+
... X_train, X_test = X[train_index], X[test_index]
803+
... y_train, y_test = y[train_index], y[test_index]
804+
TRAIN: [0 1 2 3 4 5] TEST: [6 7]
805+
TRAIN: [0 1 2 3 4 5 6 7] TEST: [8 9]
806+
TRAIN: [0 1 2 3 4 5 6 7 8 9] TEST: [10 11]
807+
>>> # Add in a 2 period gap
808+
>>> tscv = TimeSeriesSplit(n_splits=3, test_size=2, gap=2)
809+
>>> for train_index, test_index in tscv.split(X):
810+
... print("TRAIN:", train_index, "TEST:", test_index)
811+
... X_train, X_test = X[train_index], X[test_index]
812+
... y_train, y_test = y[train_index], y[test_index]
813+
TRAIN: [0 1 2 3] TEST: [6 7]
814+
TRAIN: [0 1 2 3 4 5] TEST: [8 9]
815+
TRAIN: [0 1 2 3 4 5 6 7] TEST: [10 11]
787816
788817
Notes
789818
-----
790819
The training set has size ``i * n_samples // (n_splits + 1)
791820
+ n_samples % (n_splits + 1)`` in the ``i`` th split,
792-
with a test set of size ``n_samples//(n_splits + 1)``,
821+
with a test set of size ``n_samples//(n_splits + 1)`` by default,
793822
where ``n_samples`` is the number of samples.
794823
"""
795824
@_deprecate_positional_args
796-
def __init__(self, n_splits=5, *, max_train_size=None):
825+
def __init__(self,
826+
n_splits=5,
827+
*,
828+
max_train_size=None,
829+
test_size=None,
830+
gap=0):
797831
super().__init__(n_splits, shuffle=False, random_state=None)
798832
self.max_train_size = max_train_size
833+
self.test_size = test_size
834+
self.gap = gap
799835

800836
def split(self, X, y=None, groups=None):
801837
"""Generate indices to split data into training and test set.
@@ -824,21 +860,31 @@ def split(self, X, y=None, groups=None):
824860
n_samples = _num_samples(X)
825861
n_splits = self.n_splits
826862
n_folds = n_splits + 1
863+
gap = self.gap
864+
test_size = self.test_size if self.test_size is not None \
865+
else n_samples // n_folds
866+
867+
# Make sure we have enough samples for the given split parameters
827868
if n_folds > n_samples:
828869
raise ValueError(
829-
("Cannot have number of folds ={0} greater"
830-
" than the number of samples: {1}.").format(n_folds,
831-
n_samples))
870+
(f"Cannot have number of folds={n_folds} greater"
871+
f" than the number of samples={n_samples}."))
872+
if n_samples - gap - (test_size * n_splits) <= 0:
873+
raise ValueError(
874+
(f"Too many splits={n_splits} for number of samples"
875+
f"={n_samples} with test_size={test_size} and gap={gap}."))
876+
832877
indices = np.arange(n_samples)
833-
test_size = (n_samples // n_folds)
834-
test_starts = range(test_size + n_samples % n_folds,
878+
test_starts = range(n_samples - n_splits * test_size,
835879
n_samples, test_size)
880+
836881
for test_start in test_starts:
837-
if self.max_train_size and self.max_train_size < test_start:
838-
yield (indices[test_start - self.max_train_size:test_start],
882+
train_end = test_start - gap
883+
if self.max_train_size and self.max_train_size < train_end:
884+
yield (indices[train_end - self.max_train_size:train_end],
839885
indices[test_start:test_start + test_size])
840886
else:
841-
yield (indices[:test_start],
887+
yield (indices[:train_end],
842888
indices[test_start:test_start + test_size])
843889

844890

sklearn/model_selection/tests/test_split.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,100 @@ def test_time_series_max_train_size():
14401440
_check_time_series_max_train_size(splits, check_splits, max_train_size=2)
14411441

14421442

1443+
def test_time_series_test_size():
1444+
X = np.zeros((10, 1))
1445+
1446+
# Test alone
1447+
splits = TimeSeriesSplit(n_splits=3, test_size=3).split(X)
1448+
1449+
train, test = next(splits)
1450+
assert_array_equal(train, [0])
1451+
assert_array_equal(test, [1, 2, 3])
1452+
1453+
train, test = next(splits)
1454+
assert_array_equal(train, [0, 1, 2, 3])
1455+
assert_array_equal(test, [4, 5, 6])
1456+
1457+
train, test = next(splits)
1458+
assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6])
1459+
assert_array_equal(test, [7, 8, 9])
1460+
1461+
# Test with max_train_size
1462+
splits = TimeSeriesSplit(n_splits=2, test_size=2,
1463+
max_train_size=4).split(X)
1464+
1465+
train, test = next(splits)
1466+
assert_array_equal(train, [2, 3, 4, 5])
1467+
assert_array_equal(test, [6, 7])
1468+
1469+
train, test = next(splits)
1470+
assert_array_equal(train, [4, 5, 6, 7])
1471+
assert_array_equal(test, [8, 9])
1472+
1473+
# Should fail with not enough data points for configuration
1474+
with pytest.raises(ValueError, match="Too many splits.*with test_size"):
1475+
splits = TimeSeriesSplit(n_splits=5, test_size=2).split(X)
1476+
next(splits)
1477+
1478+
1479+
def test_time_series_gap():
1480+
X = np.zeros((10, 1))
1481+
1482+
# Test alone
1483+
splits = TimeSeriesSplit(n_splits=2, gap=2).split(X)
1484+
1485+
train, test = next(splits)
1486+
assert_array_equal(train, [0, 1])
1487+
assert_array_equal(test, [4, 5, 6])
1488+
1489+
train, test = next(splits)
1490+
assert_array_equal(train, [0, 1, 2, 3, 4])
1491+
assert_array_equal(test, [7, 8, 9])
1492+
1493+
# Test with max_train_size
1494+
splits = TimeSeriesSplit(n_splits=3, gap=2, max_train_size=2).split(X)
1495+
1496+
train, test = next(splits)
1497+
assert_array_equal(train, [0, 1])
1498+
assert_array_equal(test, [4, 5])
1499+
1500+
train, test = next(splits)
1501+
assert_array_equal(train, [2, 3])
1502+
assert_array_equal(test, [6, 7])
1503+
1504+
train, test = next(splits)
1505+
assert_array_equal(train, [4, 5])
1506+
assert_array_equal(test, [8, 9])
1507+
1508+
# Test with test_size
1509+
splits = TimeSeriesSplit(n_splits=2, gap=2,
1510+
max_train_size=4, test_size=2).split(X)
1511+
1512+
train, test = next(splits)
1513+
assert_array_equal(train, [0, 1, 2, 3])
1514+
assert_array_equal(test, [6, 7])
1515+
1516+
train, test = next(splits)
1517+
assert_array_equal(train, [2, 3, 4, 5])
1518+
assert_array_equal(test, [8, 9])
1519+
1520+
# Test with additional test_size
1521+
splits = TimeSeriesSplit(n_splits=2, gap=2, test_size=3).split(X)
1522+
1523+
train, test = next(splits)
1524+
assert_array_equal(train, [0, 1])
1525+
assert_array_equal(test, [4, 5, 6])
1526+
1527+
train, test = next(splits)
1528+
assert_array_equal(train, [0, 1, 2, 3, 4])
1529+
assert_array_equal(test, [7, 8, 9])
1530+
1531+
# Verify proper error is thrown
1532+
with pytest.raises(ValueError, match="Too many splits.*and gap"):
1533+
splits = TimeSeriesSplit(n_splits=4, gap=2).split(X)
1534+
next(splits)
1535+
1536+
14431537
def test_nested_cv():
14441538
# Test if nested cross validation works with different combinations of cv
14451539
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0