ENH Added RollingWindowCV · scikit-learn/scikit-learn@cf6047e · GitHub
[go: up one dir, main page]

Skip to content

Commit cf6047e

Browse files
committed
ENH Added RollingWindowCV
1 parent aca8f20 commit cf6047e

File tree

4 files changed

+394
-0
lines changed

4 files changed

+394
-0
lines changed

doc/whats_new/v1.2.rst

+3
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,9 @@ Changelog
413413
nan score is correctly set to the maximum possible rank, rather than
414414
`np.iinfo(np.int32).min`. :pr:`24141` by :user:`Loïc Estève <lesteve>`.
415415

416+
- |Feature| Adds :class:`model_selection.RollingWindowCV`.
417+
:pr:`24589` by :user:`Maxwell Schmidt <MSchmidt99>`.
418+
416419
:mod:`sklearn.multioutput`
417420
..........................
418421

sklearn/model_selection/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ._split import GroupKFold
77
from ._split import StratifiedKFold
88
from ._split import TimeSeriesSplit
9+
from ._split import RollingWindowCV
910
from ._split import LeaveOneGroupOut
1011
from ._split import LeaveOneOut
1112
from ._split import LeavePGroupsOut
@@ -46,6 +47,7 @@
4647
"BaseShuffleSplit",
4748
"GridSearchCV",
4849
"TimeSeriesSplit",
50+
"RollingWindowCV",
4951
"KFold",
5052
"GroupKFold",
5153
"GroupShuffleSplit",

sklearn/model_selection/_split.py

+277
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Raghav RV <rvraghav93@gmail.com>
1010
# Leandro Hermida <hermidal@cs.umd.edu>
1111
# Rodion Martynov <marrodion@gmail.com>
12+
# Maxwell Schmidt <maxwelljschmidt99@gmail.com>
1213
# License: BSD 3 clause
1314

1415
from collections.abc import Iterable
@@ -1153,6 +1154,282 @@ def split(self, X, y=None, groups=None):
11531154
)
11541155

11551156

1157+
class RollingWindowCV(_BaseKFold):
    """
    A variant of TimeSeriesSplit which yields equally sized rolling windows, which
    allows for more consistent parameter tuning.

    If a time column is passed then the windows will be sized according to the time
    steps given without blending (this is useful for longitudinal data).

    Parameters
    ----------
    n_splits : int, default=4
        Number of splits.

    time_column : Iterable, default=None
        Column of the dataset containing dates. Will function identically with `None`
        when observations are not longitudinal. If observations are longitudinal then
        will facilitate splitting train and validation without date bleeding.

    train_prop : float, default=0.8
        Proportion of each window which should be allocated to training. If
        `buffer_prop` is given then true training proportion will be
        `train_prop - buffer_prop`.
        Validation proportion will always be `1 - train_prop`.

    buffer_prop : float, default=0.0
        The proportion of each window which should be allocated to nothing. Cuts into
        `train_prop`.

    slide : float, default=0.0
        `slide + 1` is the number of validation lengths to step by when generating
        windows. A value between -1.0 and 0.0 will create nearly stationary windows,
        and should be avoided unless for some odd reason it is needed.

    bias : {'left', 'right', 'train'}, default='train'
        A 'left' `bias` will yield indices beginning at 0 and not necessarily ending
        at N. A 'right' `bias` will yield indices not necessarily beginning with 0 but
        will however end at N. A 'train' `bias` will yield indices from 0 to N, with
        the overhang which would have been present with 'right' or 'left' `bias`
        allocated to the training window.

    max_long_samples : int, default=None
        If the data is longitudinal and this variable is given, the number of
        observations at each time step will be limited to the first `max_long_samples`
        samples.

    expanding : bool, default=False
        When `True` each window will begin with the first time step. This will yield
        training indices which increase in number as the window moves forwards.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RollingWindowCV
    >>> X = np.random.randn(20, 2)
    >>> y = np.random.randint(0, 2, 20)
    >>> rwcv = RollingWindowCV(n_splits=5, bias="right")
    >>> for train_index, test_index in rwcv.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [1 2 3 4 5 6 7 8 9] TEST: [10 11]
    TRAIN: [ 3  4  5  6  7  8  9 10 11] TEST: [12 13]
    TRAIN: [ 5  6  7  8  9 10 11 12 13] TEST: [14 15]
    TRAIN: [ 7  8  9 10 11 12 13 14 15] TEST: [16 17]
    TRAIN: [ 9 10 11 12 13 14 15 16 17] TEST: [18 19]
    >>> # Use a time column with longitudinal data and reduce train proportion
    >>> time_col = np.tile(np.arange(16), 2)
    >>> X = np.arange(64).reshape(32, 2)
    >>> y = np.arange(32)
    >>> rwcv = RollingWindowCV(
    ...     time_column=time_col, train_prop=0.5, n_splits=5, bias='right'
    ... )
    >>> for train_index, test_index in rwcv.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [ 1 17  2 18  3 19  4 20  5 21] TEST: [ 6 22  7 23]
    TRAIN: [ 3 19  4 20  5 21  6 22  7 23] TEST: [ 8 24  9 25]
    TRAIN: [ 5 21  6 22  7 23  8 24  9 25] TEST: [10 26 11 27]
    TRAIN: [ 7 23  8 24  9 25 10 26 11 27] TEST: [12 28 13 29]
    TRAIN: [ 9 25 10 26 11 27 12 28 13 29] TEST: [14 30 15 31]
    >>> # Bias the indices to the start of the time column
    >>> rwcv = RollingWindowCV(
    ...     time_column=time_col, train_prop=0.5, n_splits=5, bias='left'
    ... )
    >>> for train_index, test_index in rwcv.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [ 0 16  1 17  2 18  3 19  4 20] TEST: [ 5 21  6 22]
    TRAIN: [ 2 18  3 19  4 20  5 21  6 22] TEST: [ 7 23  8 24]
    TRAIN: [ 4 20  5 21  6 22  7 23  8 24] TEST: [ 9 25 10 26]
    TRAIN: [ 6 22  7 23  8 24  9 25 10 26] TEST: [11 27 12 28]
    TRAIN: [ 8 24  9 25 10 26 11 27 12 28] TEST: [13 29 14 30]
    >>> # Introduce a buffer zone between train and validation, and slide window
    >>> # by an additional validation size between windows.
    >>> X = np.arange(25)
    >>> y = np.arange(25)[::-1]
    >>> rwcv = RollingWindowCV(
    ...     train_prop=0.6, n_splits=2, buffer_prop=0.2, slide=1.0, bias="right"
    ... )
    >>> for train_index, test_index in rwcv.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...
    TRAIN: [2 3 4 5 6 7] TEST: [10 11 12 13 14]
    TRAIN: [12 13 14 15 16 17] TEST: [20 21 22 23 24]
    """

    def __init__(
        self,
        n_splits=4,
        *,
        time_column=None,
        train_prop=0.8,
        buffer_prop=0.0,
        slide=0.0,
        bias="train",
        max_long_samples=None,
        expanding=False,
    ):
        # Validate up front so misconfiguration fails at construction rather
        # than at split time.
        if buffer_prop > train_prop:
            raise ValueError(
                "Buffer proportion cannot be greater than training proportion."
            )
        if slide < -1.0:
            raise ValueError("slide cannot be less than -1.0")
        if bias not in ("right", "left", "train"):
            raise ValueError("Invalid value for bias.")

        # NOTE(review): _BaseKFold.__init__ is intentionally not called here;
        # attributes are assigned directly instead.
        self.n_splits = n_splits
        self.time_column = time_column
        self.train_prop = train_prop
        self.buffer_prop = buffer_prop
        test_prop = 1 - train_prop
        # Fraction of the data covered by a single window, derived so that
        # n_splits windows (each stepping by (slide + 1) validation lengths)
        # tile the sample range.
        self.batch_size = (1 + (test_prop * (slide + 1) * (n_splits - 1))) ** (-1)
        self.slide = slide
        self.bias = bias
        if max_long_samples is not None:
            max_long_samples += 1  # index slice end is exclusive
        self.max_long_samples = max_long_samples
        self.expanding = expanding

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        if self.time_column is None:
            X, y, groups = indexable(X, y, groups)
            n_samples = _num_samples(X)
        else:
            # Longitudinal mode: windows are laid out over the *unique* time
            # steps (dict.fromkeys preserves first-seen order) so that all
            # rows sharing a time step land on the same side of a split.
            X = self.time_column
            X, y, groups = indexable(X, y, groups)
            X_unique = np.array(list(dict.fromkeys(X)))
            n_samples = _num_samples(X_unique)

        if self.n_splits > n_samples:
            raise ValueError(
                f"Cannot have number of folds={self.n_splits} greater"
                f" than the number of samples={n_samples}."
            )

        # batch_size is a fraction (computed in __init__); resolve it to an
        # absolute window length in samples / time steps.
        if isinstance(self.batch_size, float) and self.batch_size < 1:
            length_per_iter = int(n_samples * self.batch_size)
        elif isinstance(self.batch_size, int) and self.batch_size >= 1:
            length_per_iter = self.batch_size
        else:
            raise ValueError(
                "batch_size must be decimal between 0 and 1.0 or whole number greater "
                f"than or equal to 1 (got {self.batch_size})."
            )

        test_size = int(length_per_iter * (1 - self.train_prop))
        if test_size < 1:
            raise ValueError(
                "Inferred batch size with batch test proportion of "
                f"{1 - self.train_prop:0.2f}, slide of {self.slide:0.2f}, and "
                f"n_splits of {self.n_splits} is {length_per_iter}. Each batches "
                "testing length is thus "
                f"{length_per_iter * (1 - self.train_prop):0.2f}, which must not be "
                "less than 1.0"
            )
        buffer_size = int(length_per_iter * self.buffer_prop)
        train_size = length_per_iter - test_size - buffer_size

        # Total span the n_splits windows will cover; if enough samples are
        # left over to fit another test length, grow the training side.
        used_indices_len = (
            test_size * (self.slide + 1) * (self.n_splits - 1) + length_per_iter
        )
        # difference is expected to be 1 or 0, so only affects data sets with
        # very few samples.
        if n_samples - used_indices_len >= test_size:
            train_size += test_size
            length_per_iter += test_size

        if self.bias == "left":
            train_starts = range(
                0, n_samples - length_per_iter + 1, int(test_size * (self.slide + 1))
            )
        else:
            # Samples that do not fit an exact number of window steps; 'right'
            # skips them at the start, 'train' folds them into each train set.
            overhang = (n_samples - length_per_iter) % int(test_size * (self.slide + 1))
            if self.bias == "right":
                train_starts = range(
                    overhang,
                    n_samples - length_per_iter + 1,
                    int(test_size * (self.slide + 1)),
                )
            elif self.bias == "train":
                length_per_iter += overhang
                train_size += overhang
                train_starts = range(
                    0,
                    n_samples - length_per_iter + 1,
                    int(test_size * (self.slide + 1)),
                )

        if self.time_column is None:
            indices = np.arange(n_samples)
            for train_start in train_starts:
                yield (
                    indices[
                        0 if self.expanding else train_start : train_start + train_size
                    ],
                    indices[
                        train_start
                        + train_size
                        + buffer_size : train_start
                        + length_per_iter
                    ],
                )
        else:
            # Map each unique time step in the window back to the row indices
            # holding that time step, optionally capped at max_long_samples
            # rows per step (slice end precomputed in __init__).
            for train_start in train_starts:
                yield (
                    np.concatenate(
                        [
                            np.argwhere(X == x_u).flatten()[: self.max_long_samples]
                            for x_u in X_unique[
                                0
                                if self.expanding
                                else train_start : train_start + train_size
                            ]
                        ]
                    ),
                    np.concatenate(
                        [
                            np.argwhere(X == x_u).flatten()[: self.max_long_samples]
                            for x_u in X_unique[
                                train_start
                                + train_size
                                + buffer_size : train_start
                                + length_per_iter
                            ]
                        ]
                    ),
                )
1431+
1432+
11561433
class LeaveOneGroupOut(BaseCrossValidator):
11571434
"""Leave One Group Out cross-validator
11581435

0 commit comments

Comments
 (0)
0