8000 ENH Added RollingWindowCV · scikit-learn/scikit-learn@e0d7ebf · GitHub
[go: up one dir, main page]

Skip to content

Commit e0d7ebf

Browse files
committed
ENH Added RollingWindowCV
1 parent aca8f20 commit e0d7ebf

File tree

4 files changed

+337
-0
lines changed

4 files changed

+337
-0
lines changed

doc/whats_new/v1.2.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,9 @@ Changelog
413413
nan score is correctly set to the maximum possible rank, rather than
414414
`np.iinfo(np.int32).min`. :pr:`24141` by :user:`Loïc Estève <lesteve>`.
415415

416+
- |Feature| Adds :class:`model_selection.RollingWindowCV`.
417+
:pr:`24589` by :user:`Maxwell Schmidt <MSchmidt99>`
418+
416419
:mod:`sklearn.multioutput`
417420
..........................
418421

sklearn/model_selection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ._split import GroupKFold
77
from ._split import StratifiedKFold
88
from ._split import TimeSeriesSplit
9+
from ._split import RollingWindowCV
910
from ._split import LeaveOneGroupOut
1011
from ._split import LeaveOneOut
1112
from ._split import LeavePGroupsOut
@@ -46,6 +47,7 @@
4647
"BaseShuffleSplit",
4748
"GridSearchCV",
4849
"TimeSeriesSplit",
50+
"RollingWindowCV",
4951
"KFold",
5052
"GroupKFold",
5153
"GroupShuffleSplit",

sklearn/model_selection/_split.py

Lines changed: 256 additions & 0 deletions
9E88
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Raghav RV <rvraghav93@gmail.com>
1010
# Leandro Hermida <hermidal@cs.umd.edu>
1111
# Rodion Martynov <marrodion@gmail.com>
12+
# Maxwell Schmidt <maxwelljschmidt99@gmail.com>
1213
# License: BSD 3 clause
1314

1415
from collections.abc import Iterable
@@ -1153,6 +1154,261 @@ def split(self, X, y=None, groups=None):
11531154
)
11541155

11551156

1157+
class RollingWindowCV(_BaseKFold):
1158+
"""
1159+
A variant of TimeSeriesSplit which yields equally sized rolling windows, which
1160+
allows for more consistent parameter tuning.
1161+
1162+
If a time column is passed then the windows will be sized according to the time
1163+
steps given without blending (this is useful for longitudinal data).
1164+
1165+
Parameters
1166+
----------
1167+
n_splits : int, default=5
1168+
Number of splits.
1169+
1170+
time_column : Iterable, default=None
1171+
Column of the dataset containing dates. Will function identically with `None`
1172+
when observations are not longitudinal. If observations are longitudinal then
1173+
will facilitate splitting train and validation without date bleeding.
1174+
1175+
train_prop : float, default=0.8
1176+
Proportion of each window which should be allocated to training. If
1177+
`buffer_prop` is given then true training proportion will be
1178+
`train_prop - buffer_prop`.
1179+
Validation proportion will always be `1 - train_prop`.
1180+
1181+
buffer_prop : float, default=0.0
1182+
The proportion of each window which should be allocated to nothing. Cuts into
1183+
`train_prop`.
1184+
1185+
slide : float, default=0.0
1186+
`slide + 1` is the number of validation lenghts to step by when generating
1187+
windows. A value between -1.0 and 0.0 will create nearly stationary windows,
1188+
and should be avoided unless for some odd reason it is needed.
1189+
1190+
bias : {'left', 'right'}, default='right'
1191+
A 'left' `bias` will yeld indicies beginning at 0 and not necessarily ending
1192+
at N. A 'right' `bias` will yield indicies not necessarily beginning with 0 but
1193+
will however end at N.
1194+
1195+
max_long_samples : int, default=None
1196+
If the data is longitudinal and this variable is given, the number of
1197+
observations at each time step will be limited to the first `max_long_samples`
1198+
samples.
1199+
1200+
Examples
1201+
--------
1202+
>>> import numpy as np
1203+
>>> from sklearn.model_selection import RollingWindowCV
1204+
>>> X = np.random.randn(20, 2)
1205+
>>> y = np.random.randint(0, 2, 20)
1206+
>>> rwcv = RollingWindowCV(n_splits=5)
1207+
>>> for train_index, test_index in rwcv.split(X):
1208+
... print("TRAIN:", train_index, "TEST:", test_index)
1209+
... X_train, X_test = X[train_index], X[test_index]
1210+
... y_train, y_test = y[train_index], y[test_index]
1211+
TRAIN: [1 2 3 4 5 6 7 8 9] TEST: [10 11]
1212+
TRAIN: [ 3 4 5 6 7 8 9 10 11] TEST: [12 13]
1213+
TRAIN: [ 5 6 7 8 9 10 11 12 13] TEST: [14 15]
1214+
TRAIN: [ 7 8 9 10 11 12 13 14 15] TEST: [16 17]
1215+
TRAIN: [ 9 10 11 12 13 14 15 16 17] TEST: [18 19]
1216+
>>> # Use a time column with longitudinal data and reduce train proportion
1217+
>>> time_col = np.tile(np.arange(16), 2)
1218+
>>> X = np.arange(64).reshape(32, 2)
1219+
>>> y = np.arange(32)
1220+
>>> rwcv = RollingWindowCV(
1221+
... time_column=time_col, train_prop=0.5, n_splits=5, bias='right'
1222+
... )
1223+
>>> for train_index, test_index in rwcv.split(X):
1224+
... print("TRAIN:", train_index, "TEST:", test_index)
1225+
... X_train, X_test = X[train_index], X[test_index]
1226+
... y_train, y_test = y[train_index], y[test_index]
1227+
TRAIN: [ 1 17 2 18 3 19 4 20 5 21] TEST: [ 6 22 7 23]
1228+
TRAIN: [ 3 19 4 20 5 21 6 22 7 23] TEST: [ 8 24 9 25]
1229+
TRAIN: [ 5 21 6 22 7 23 8 24 9 25] TEST: [10 26 11 27]
1230+
TRAIN: [ 7 23 8 24 9 25 10 26 11 27] TEST: [12 28 13 29]
1231+
TRAIN: [ 9 25 10 26 11 27 12 28 13 29] TEST: [14 30 15 31]
1232+
>>> # Bias the indicies to the start of the time column
1233+
>>> rwcv = RollingWindowCV(
1234+
... time_column=time_col, train_prop=0.5, n_splits=5, bias='left'
1235+
... )
1236+
>>> for train_index, test_index in rwcv.split(X):
1237+
... print("TRAIN:", train_index, "TEST:", test_index)
1238+
... X_train, X_test = X[train_index], X[test_index]
1239+
... y_train, y_test = y[train_index], y[test_index]
1240+
TRAIN: [ 0 16 1 17 2 18 3 19 4 20] TEST: [ 5 21 6 22]
1241+
TRAIN: [ 2 18 3 19 4 20 5 21 6 22] TEST: [ 7 23 8 24]
1242+
TRAIN: [ 4 20 5 21 6 22 7 23 8 24] TEST: [ 9 25 10 26]
1243+
TRAIN: [ 6 22 7 23 8 24 9 25 10 26] TEST: [11 27 12 28]
1244+
TRAIN: [ 8 24 9 25 10 26 11 27 12 28] TEST: [13 29 14 30]
1245+
>>> # Introduce a buffer zone between train and validation, and slide window
1246+
>>> # by an additional validation size between windows.
1247+
>>> X = np.arange(25)
1248+
>>> Y = np.arange(25)[::-1]
1249+
>>> rwcv = RollingWindowCV(train_prop=0.6, n_splits=2, buffer_prop=0.2, slide=1.0)
1250+
>>> for train_index, test_index in rwcv.split(X):
1251+
... print("TRAIN:", train_index, "TEST:", test_index)
1252+
... X_train, X_test = X[train_index], X[test_index]
1253+
... y_train, y_test = y[train_index], y[test_index]
1254+
...
1255+
TRAIN: [2 3 4 5 6 7] TEST: [10 11 12 13 14]
1256+
TRAIN: [12 13 14 15 16 17] TEST: [20 21 22 23 24]
1257+
"""
1258+
1259+
def __init__(
1260+
self,
1261+
n_splits=4,
1262+
*,
1263+
time_column=None,
1264+
train_prop=0.8,
1265+
buffer_prop=0.0,
1266+
slide=0.0,
1267+
bias="right",
1268+
max_long_samples=None,
1269+
):
1270+
if buffer_prop > train_prop:
1271+
raise ValueError(
1272+
"Buffer proportion cannot be greater than training proportion."
1273+
)
1274+
if slide < -1.0:
1275+
raise ValueError("slide cannot be less than -1.0")
1276+
1277+
self.n_splits = n_splits
1278+
self.time_column = time_column
1279+
self.train_prop = train_prop
1280+
self.buffer_prop = buffer_prop
1281+
test_prop = 1 - train_prop
1282+
self.batch_size = (1 + (test_prop * (slide + 1) * (n_splits - 1))) ** (-1)
1283+
self.slide = slide
1284+
self.bias = bias
1285+
if max_long_samples is not None:
1286+
max_long_samples += 1 # index slice end is exclusivve
1287+
self.max_long_samples = max_long_samples
1288+
1289+
def split(self, X, y=None, groups=None):
1290+
"""Generate indices to split data into training and test set.
1291+
1292+
Parameters
1293+
----------
1294+
X : array-like of shape (n_samples, n_features)
1295+
Training data, where `n_samples` is the number of samples
1296+
and `n_features` is the number of features.
1297+
1298+
y : array-like of shape (n_samples,)
1299+
Always ignored, exists for compatibility.
1300+
1301+
groups : array-like of shape (n_samples,)
1302+
Always ignored, exists for compatibility.
1303+
1304+
Yields
1305+
------
1306+
train : ndarray
1307+
The training set indices for that split.
1308+
1309+
test : ndarray
1310+
The testing set indices for that split.
1311+
"""
1312+
if self.time_column is None:
1313+
X, y, groups = indexable(X, y, groups)
1314+
n_samples = _num_samples(X)
1315+
else:
1316+
X = self.time_column
1317+
X, y, groups = indexable(X, y, groups)
1318+
X_time = np.array(list(dict.fromkeys(X)))
1319+
n_samples = _num_samples(X_time)
1320+
1321+
if self.n_splits > n_samples:
1322+
raise ValueError(
1323+
f"Cannot have number of folds={self.n_splits} greater"
1324+
f" than the number of samples={n_samples}."
1325+
)
1326+
1327+
if isinstance(self.batch_size, float) and self.batch_size < 1:
1328+
length_per_iter = int(n_samples * self.batch_size)
1329+
elif isinstance(self.batch_size, int) and self.batch_size >= 1:
1330+
length_per_iter = self.batch_size
1331+
else:
1332+
raise ValueError(
1333+
"batch_size must be decimal between 0 and 1.0 or whole number greater "
1334+
f"than or equal to 1 (got {self.batch_size})."
1335+
)
1336+
1337+
test_size = int(length_per_iter * (1 - self.train_prop))
1338+
if test_size < 1:
1339+
raise ValueError(
1340+
"Inferred batch size with batch test proportion of "
1341+
f"{1 - self.train_prop:0.2f}, slide of {self.slide:0.2f}, and "
1342+
f"n_splits of {self.n_splits} is {length_per_iter}. Each batches "
1343+
"testing length is thus "
1344+
f"{length_per_iter * (1 - self.train_prop):0.2f}, which must not be "
1345+
"less than 1.0"
1346+
)
1347+
buffer_size = int(length_per_iter * self.buffer_prop)
1348+
train_size = length_per_iter - test_size - buffer_size
1349+
1350+
used_indices_len = (
1351+
test_size * (self.slide + 1) * (self.n_splits - 1) + length_per_iter
1352+
)
1353+
# difference is expected to be 1 or 0, so only effects data sets with
1354+
# very few samples.
1355+
if n_samples - used_indices_len >= test_size:
1356+
train_size += test_size
1357+
length_per_iter += test_size
1358+
1359+
if self.bias == "left":
1360+
train_starts = range(
1361+
0, n_samples - length_per_iter + 1, int(test_size * (self.slide + 1))
1362+
)
1363+
elif self.bias == "right":
1364+
overhang = (n_samples - length_per_iter) % int(test_size * (self.slide + 1))
1365+
train_starts = range(
1366+
overhang,
1367+
n_samples - length_per_iter + 1,
1368+
int(test_size * (self.slide + 1)),
1369+
)
1370+
else:
1371+
raise ValueError(f"{self.bias} is not a valid option for bias.")
1372+
1373+
if self.time_column is None:
1374+
indices = np.arange(n_samples)
1375+
for train_start in train_starts:
1376+
yield (
1377+
indices[train_start : train_start + train_size],
1378+
indices[
1379+
train_start
1380+
+ train_size
1381+
+ buffer_size : train_start
1382+
+ length_per_iter
1383+
],
1384+
)
1385+
else:
1386+
for train_start in train_starts:
1387+
yield (
1388+
np.concatenate(
1389+
[
1390+
np.array([i for i, x2 in enumerate(X) if x == x2])[
1391+
: self.max_long_samples
1392+
]
1393+
for x in X_time[train_start : train_start + train_size]
1394+
]
1395+
).astype(int),
1396+
np.concatenate(
1397+
[
1398+
np.array([i for i, x2 in enumerate(X) if x == x2])[
1399+
: self.max_long_samples
1400+
]
1401+
for x in X_time[
1402+
train_start
1403+
+ train_size
1404+
+ buffer_size : train_start
1405+
+ length_per_iter
1406+
]
1407+
]
1408+
).astype(int),
1409+
)
1410+
1411+
11561412
class LeaveOneGroupOut(BaseCrossValidator):
11571413
"""Leave One Group Out cross-validator
11581414

sklearn/model_selection/tests/test_split.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sklearn.model_selection import StratifiedKFold
2323
from sklearn.model_selection import GroupKFold
2424
from sklearn.model_selection import TimeSeriesSplit
25+
from sklearn.model_selection import RollingWindowCV
2526
from sklearn.model_selection import LeaveOneOut
2627
from sklearn.model_selection import LeaveOneGroupOut
2728
from sklearn.model_selection import LeavePOut
@@ -172,6 +173,7 @@ def test_2d_y():
172173
LeavePGroupsOut(n_groups=2),
173174
GroupKFold(n_splits=3),
174175
TimeSeriesSplit(),
176+
RollingWindowCV(),
175177
PredefinedSplit(test_fold=groups),
176178
]
177179
for splitter in splitters:
@@ -1782,6 +1784,80 @@ def test_time_series_gap():
17821784
next(splits)
17831785

17841786

1787+
def test_rolling_window_cv():
1788+
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
1789+
1790+
# Should fail if there are more folds than samples
1791+
with pytest.raises(ValueError, match="Cannot have number of folds.*greater"):
1792+
next(RollingWindowCV(n_splits=9).split(X))
1793+
1794+
rwcv = RollingWindowCV(2)
1795+
1796+
# Manually check that Rolling Window CV preserves the data
1797+
# ordering on toy datasets
1798+
splits = rwcv.split(X)
1799+
train, test = next(splits)
1800+
assert_array_equal(train, [0, 1, 2, 3, 4, 5])
1801+
assert_array_equal(test, [6])
1802+
1803+
train, test = next(splits)
1804+
assert_array_equal(train, [1, 2, 3, 4, 5, 6])
1805+
assert_array_equal(test, [7])
1806+
1807+
# Check get_n_splits returns the correct number of splits
1808+
splits = RollingWindowCV(2).split(X)
1809+
n_splits_actual = len(list(splits))
1810+
assert n_splits_actual == rwcv.get_n_splits()
1811+
assert n_splits_actual == 2
1812+
1813+
1814+
def test_rolling_window_longitudinal():
1815+
time_col = np.tile(np.arange(8), 2)
1816+
X = np.arange(32).reshape(16, 2)
1817+
1818+
splits = RollingWindowCV(time_column=time_col, train_prop=0.5, n_splits=3).split(X)
1819+
train, test = next(splits)
1820+
assert_array_equal(train, [0, 8, 1, 9])
1821+
assert_array_equal(test, [2, 10, 3, 11])
1822+
train, test = next(splits)
1823+
assert_array_equal(train, [2, 10, 3, 11])
1824+
assert_array_equal(test, [4, 12, 5, 13])
1825+
train, test = next(splits)
1826+
assert_array_equal(train, [4, 12, 5, 13])
1827+
assert_array_equal(test, [6, 14, 7, 15])
1828+
1829+
1830+
def test_rolling_window_params():
1831+
X = np.zeros((40, 1))
1832+
1833+
# slide
1834+
splits = RollingWindowCV(2, slide=1.0).split(X)
1835+
train, test = next(splits)
1836+
assert_array_equal(train, np.arange(2, 25))
1837+
assert_array_equal(test, np.arange(25, 30))
1838+
train, test = next(splits)
1839+
assert_array_equal(train, np.arange(12, 35))
1840+
assert_array_equal(test, np.arange(35, 40))
1841+
1842+
# left bias
1843+
splits = RollingWindowCV(2, slide=1.0, bias="left").split(X)
1844+
train, test = next(splits)
1845+
assert_array_equal(train, np.arange(0, 23))
1846+
assert_array_equal(test, np.arange(23, 28))
1847+
train, test = next(splits)
1848+
assert_array_equal(train, np.arange(10, 33))
1849+
assert_array_equal(test, np.arange(33, 38))
1850+
1851+
# buffer
1852+
splits = RollingWindowCV(2, slide=1.0, buffer_prop=0.1).split(X)
1853+
train, test = next(splits)
1854+
assert_array_equal(train, np.arange(2, 23))
1855+
assert_array_equal(test, np.arange(25, 30))
1856+
train, test = next(splits)
1857+
assert_array_equal(train, np.arange(12, 33))
1858+
assert_array_equal(test, np.arange(35, 40))
1859+
1860+
17851861
def test_nested_cv():
17861862
# Test if nested cross validation works with different combinations of cv
17871863
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0