diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 6a1b05badd4a8..18b1d0784e4c8 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -417,6 +417,9 @@ Changelog nan score is correctly set to the maximum possible rank, rather than `np.iinfo(np.int32).min`. :pr:`24141` by :user:`Loïc Estève `. +- |Feature| Adds :class:`model_selection.RollingWindowCV`. + :pr:`24589` by :user:`Maxwell Schmidt ` + :mod:`sklearn.multioutput` .......................... diff --git a/sklearn/model_selection/__init__.py b/sklearn/model_selection/__init__.py index a481f5db72fdf..1191c09005fb2 100644 --- a/sklearn/model_selection/__init__.py +++ b/sklearn/model_selection/__init__.py @@ -6,6 +6,7 @@ from ._split import GroupKFold from ._split import StratifiedKFold from ._split import TimeSeriesSplit +from ._split import RollingWindowCV from ._split import LeaveOneGroupOut from ._split import LeaveOneOut from ._split import LeavePGroupsOut @@ -46,6 +47,7 @@ "BaseShuffleSplit", "GridSearchCV", "TimeSeriesSplit", + "RollingWindowCV", "KFold", "GroupKFold", "GroupShuffleSplit", diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 61e2caa4ca7fa..2987a142ea3f2 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -9,6 +9,7 @@ # Raghav RV # Leandro Hermida # Rodion Martynov +# Maxwell Schmidt # License: BSD 3 clause from collections.abc import Iterable @@ -1153,6 +1154,282 @@ def split(self, X, y=None, groups=None): ) +class RollingWindowCV(_BaseKFold): + """ + A variant of TimeSeriesSplit which yields equally sized rolling windows, which + allows for more consistent parameter tuning. + + If a time column is passed then the windows will be sized according to the time + steps given without blending (this is useful for longitudinal data). + + Parameters + ---------- + n_splits : int, default=5 + Number of splits. + + time_column : Iterable, default=None + Column of the dataset containing dates. Will function identically with `None` + when observations are not longitudinal. If observations are longitudinal then + will facilitate splitting train and validation without date bleeding. + + train_prop : float, default=0.8 + Proportion of each window which should be allocated to training. If + `buffer_prop` is given then true training proportion will be + `train_prop - buffer_prop`. + Validation proportion will always be `1 - train_prop`. + + buffer_prop : float, default=0.0 + The proportion of each window which should be allocated to nothing. Cuts into + `train_prop`. + + slide : float, default=0.0 + `slide + 1` is the number of validation lenghts to step by when generating + windows. A value between -1.0 and 0.0 will create nearly stationary windows, + and should be avoided unless for some odd reason it is needed. + + bias : {'left', 'right', 'train'}, default='train' + A 'left' `bias` will yeld indicies beginning at 0 and not necessarily ending + at N. A 'right' `bias` will yield indicies not necessarily beginning with 0 but + will however end at N. A 'train' `bias` will yield indices from 0 to N, with + the overhang which would have been present with 'right' or 'left' `bias` + allocated to the training window. + + max_long_samples : int, default=None + If the data is longitudinal and this variable is given, the number of + observations at each time step will be limited to the first `max_long_samples` + samples. + + expanding : bool, default=False + When `True` each window will begin with the first time step. This will yeild + training indicies which increase in number as the window moves forwards. + + Examples + -------- + >>> import numpy as np + >>> from sklearn.model_selection import RollingWindowCV + >>> X = np.random.randn(20, 2) + >>> y = np.random.randint(0, 2, 20) + >>> rwcv = RollingWindowCV(n_splits=5, bias="right") + >>> for train_index, test_index in rwcv.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [1 2 3 4 5 6 7 8 9] TEST: [10 11] + TRAIN: [ 3 4 5 6 7 8 9 10 11] TEST: [12 13] + TRAIN: [ 5 6 7 8 9 10 11 12 13] TEST: [14 15] + TRAIN: [ 7 8 9 10 11 12 13 14 15] TEST: [16 17] + TRAIN: [ 9 10 11 12 13 14 15 16 17] TEST: [18 19] + >>> # Use a time column with longitudinal data and reduce train proportion + >>> time_col = np.tile(np.arange(16), 2) + >>> X = np.arange(64).reshape(32, 2) + >>> y = np.arange(32) + >>> rwcv = RollingWindowCV( + ... time_column=time_col, train_prop=0.5, n_splits=5, bias='right' + ... ) + >>> for train_index, test_index in rwcv.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [ 1 17 2 18 3 19 4 20 5 21] TEST: [ 6 22 7 23] + TRAIN: [ 3 19 4 20 5 21 6 22 7 23] TEST: [ 8 24 9 25] + TRAIN: [ 5 21 6 22 7 23 8 24 9 25] TEST: [10 26 11 27] + TRAIN: [ 7 23 8 24 9 25 10 26 11 27] TEST: [12 28 13 29] + TRAIN: [ 9 25 10 26 11 27 12 28 13 29] TEST: [14 30 15 31] + >>> # Bias the indicies to the start of the time column + >>> rwcv = RollingWindowCV( + ... time_column=time_col, train_prop=0.5, n_splits=5, bias='left' + ... ) + >>> for train_index, test_index in rwcv.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + TRAIN: [ 0 16 1 17 2 18 3 19 4 20] TEST: [ 5 21 6 22] + TRAIN: [ 2 18 3 19 4 20 5 21 6 22] TEST: [ 7 23 8 24] + TRAIN: [ 4 20 5 21 6 22 7 23 8 24] TEST: [ 9 25 10 26] + TRAIN: [ 6 22 7 23 8 24 9 25 10 26] TEST: [11 27 12 28] + TRAIN: [ 8 24 9 25 10 26 11 27 12 28] TEST: [13 29 14 30] + >>> # Introduce a buffer zone between train and validation, and slide window + >>> # by an additional validation size between windows. + >>> X = np.arange(25) + >>> Y = np.arange(25)[::-1] + >>> rwcv = RollingWindowCV( + ... train_prop=0.6, n_splits=2, buffer_prop=0.2, slide=1.0, bias="right" + ... ) + >>> for train_index, test_index in rwcv.split(X): + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + ... + TRAIN: [2 3 4 5 6 7] TEST: [10 11 12 13 14] + TRAIN: [12 13 14 15 16 17] TEST: [20 21 22 23 24] + """ + + def __init__( + self, + n_splits=4, + *, + time_column=None, + train_prop=0.8, + buffer_prop=0.0, + slide=0.0, + bias="train", + max_long_samples=None, + expanding=False, + ): + if buffer_prop > train_prop: + raise ValueError( + "Buffer proportion cannot be greater than training proportion." + ) + if slide < -1.0: + raise ValueError("slide cannot be less than -1.0") + if bias not in ("right", "left", "train"): + raise ValueError("Invalid value for bias.") + + self.n_splits = n_splits + self.time_column = time_column + self.train_prop = train_prop + self.buffer_prop = buffer_prop + test_prop = 1 - train_prop + self.batch_size = (1 + (test_prop * (slide + 1) * (n_splits - 1))) ** (-1) + self.slide = slide + self.bias = bias + if max_long_samples is not None: + max_long_samples += 1 # index slice end is exclusivve + self.max_long_samples = max_long_samples + self.expanding = expanding + + def split(self, X, y=None, groups=None): + """Generate indices to split data into training and test set. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data, where `n_samples` is the number of samples + and `n_features` is the number of features. + + y : array-like of shape (n_samples,) + Always ignored, exists for compatibility. + + groups : array-like of shape (n_samples,) + Always ignored, exists for compatibility. + + Yields + ------ + train : ndarray + The training set indices for that split. + + test : ndarray + The testing set indices for that split. + """ + if self.time_column is None: + X, y, groups = indexable(X, y, groups) + n_samples = _num_samples(X) + else: + X = self.time_column + X, y, groups = indexable(X, y, groups) + X_unique = np.array(list(dict.fromkeys(X))) + n_samples = _num_samples(X_unique) + + if self.n_splits > n_samples: + raise ValueError( + f"Cannot have number of folds={self.n_splits} greater" + f" than the number of samples={n_samples}." + ) + + if isinstance(self.batch_size, float) and self.batch_size < 1: + length_per_iter = int(n_samples * self.batch_size) + elif isinstance(self.batch_size, int) and self.batch_size >= 1: + length_per_iter = self.batch_size + else: + raise ValueError( + "batch_size must be decimal between 0 and 1.0 or whole number greater " + f"than or equal to 1 (got {self.batch_size})." + ) + + test_size = int(length_per_iter * (1 - self.train_prop)) + if test_size < 1: + raise ValueError( + "Inferred batch size with batch test proportion of " + f"{1 - self.train_prop:0.2f}, slide of {self.slide:0.2f}, and " + f"n_splits of {self.n_splits} is {length_per_iter}. Each batches " + "testing length is thus " + f"{length_per_iter * (1 - self.train_prop):0.2f}, which must not be " + "less than 1.0" + ) + buffer_size = int(length_per_iter * self.buffer_prop) + train_size = length_per_iter - test_size - buffer_size + + used_indices_len = ( + test_size * (self.slide + 1) * (self.n_splits - 1) + length_per_iter + ) + # difference is expected to be 1 or 0, so only effects data sets with + # very few samples. + if n_samples - used_indices_len >= test_size: + train_size += test_size + length_per_iter += test_size + + if self.bias == "left": + train_starts = range( + 0, n_samples - length_per_iter + 1, int(test_size * (self.slide + 1)) + ) + else: + overhang = (n_samples - length_per_iter) % int(test_size * (self.slide + 1)) + if self.bias == "right": + train_starts = range( + overhang, + n_samples - length_per_iter + 1, + int(test_size * (self.slide + 1)), + ) + elif self.bias == "train": + length_per_iter += overhang + train_size += overhang + train_starts = range( + 0, + n_samples - length_per_iter + 1, + int(test_size * (self.slide + 1)), + ) + + if self.time_column is None: + indices = np.arange(n_samples) + for train_start in train_starts: + yield ( + indices[ + 0 if self.expanding else train_start : train_start + train_size + ], + indices[ + train_start + + train_size + + buffer_size : train_start + + length_per_iter + ], + ) + else: + for train_start in train_starts: + yield ( + np.concatenate( + [ + np.argwhere(X == x_u).flatten()[: self.max_long_samples] + for x_u in X_unique[ + 0 + if self.expanding + else train_start : train_start + train_size + ] + ] + ), + np.concatenate( + [ + np.argwhere(X == x_u).flatten()[: self.max_long_samples] + for x_u in X_unique[ + train_start + + train_size + + buffer_size : train_start + + length_per_iter + ] + ] + ), + ) + + class LeaveOneGroupOut(BaseCrossValidator): """Leave One Group Out cross-validator diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index f502ebc8a3b6a..467099c0d0b97 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -22,6 +22,7 @@ from sklearn.model_selection import StratifiedKFold from sklearn.model_selection import GroupKFold from sklearn.model_selection import TimeSeriesSplit +from sklearn.model_selection import RollingWindowCV from sklearn.model_selection import LeaveOneOut from sklearn.model_selection import LeaveOneGroupOut from sklearn.model_selection import LeavePOut @@ -172,6 +173,7 @@ def test_2d_y(): LeavePGroupsOut(n_groups=2), GroupKFold(n_splits=3), TimeSeriesSplit(), + RollingWindowCV(), PredefinedSplit(test_fold=groups), ] for splitter in splitters: @@ -1782,6 +1784,116 @@ def test_time_series_gap(): next(splits) +def test_rolling_window_cv(): + X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] + + # Should fail if there are more folds than samples + with pytest.raises(ValueError, match="Cannot have number of folds.*greater"): + next(RollingWindowCV(n_splits=9).split(X)) + + with pytest.raises(ValueError): + RollingWindowCV(train_prop=0.8, buffer_prop=0.9) + + with pytest.raises(ValueError): + RollingWindowCV(slide=-2.0) + + with pytest.raises(ValueError): + next(RollingWindowCV(train_prop=0.99).split(X)) + + with pytest.raises(ValueError): + RollingWindowCV(bias="up") + + rwcv = RollingWindowCV(2, bias="right") + + # Manually check that Rolling Window CV preserves the data + # ordering on toy datasets + splits = rwcv.split(X) + train, test = next(splits) + assert_array_equal(train, [0, 1, 2, 3, 4, 5]) + assert_array_equal(test, [6]) + + train, test = next(splits) + assert_array_equal(train, [1, 2, 3, 4, 5, 6]) + assert_array_equal(test, [7]) + + # Check get_n_splits returns the correct number of splits + splits = RollingWindowCV(2, bias="right").split(X) + n_splits_actual = len(list(splits)) + assert n_splits_actual == rwcv.get_n_splits() + assert n_splits_actual == 2 + + +def test_rolling_window_longitudinal(): + time_col = np.tile(np.arange(8), 2) + X = np.arange(32).reshape(16, 2) + + splits = RollingWindowCV( + time_column=time_col, + train_prop=0.5, + n_splits=3, + max_long_samples=2, + bias="right", + ).split(X) + train, test = next(splits) + assert_array_equal(train, [0, 8, 1, 9]) + assert_array_equal(test, [2, 10, 3, 11]) + train, test = next(splits) + assert_array_equal(train, [2, 10, 3, 11]) + assert_array_equal(test, [4, 12, 5, 13]) + train, test = next(splits) + assert_array_equal(train, [4, 12, 5, 13]) + assert_array_equal(test, [6, 14, 7, 15]) + + +def test_rolling_window_params(): + X = np.zeros((40, 1)) + + # slide + splits = RollingWindowCV(2, slide=1.0, bias="right").split(X) + train, test = next(splits) + assert_array_equal(train, np.arange(2, 25)) + assert_array_equal(test, np.arange(25, 30)) + train, test = next(splits) + assert_array_equal(train, np.arange(12, 35)) + assert_array_equal(test, np.arange(35, 40)) + + # left bias + splits = RollingWindowCV(2, slide=1.0, bias="left").split(X) + train, test = next(splits) + assert_array_equal(train, np.arange(0, 23)) + assert_array_equal(test, np.arange(23, 28)) + train, test = next(splits) + assert_array_equal(train, np.arange(10, 33)) + assert_array_equal(test, np.arange(33, 38)) + + # train bias + splits = RollingWindowCV(2, slide=1.0, bias="train").split(X) + train, test = next(splits) + assert_array_equal(train, np.arange(0, 25)) + assert_array_equal(test, np.arange(25, 30)) + train, test = next(splits) + assert_array_equal(train, np.arange(10, 35)) + assert_array_equal(test, np.arange(35, 40)) + + # buffer + splits = RollingWindowCV(2, slide=1.0, buffer_prop=0.1, bias="right").split(X) + train, test = next(splits) + assert_array_equal(train, np.arange(2, 23)) + assert_array_equal(test, np.arange(25, 30)) + train, test = next(splits) + assert_array_equal(train, np.arange(12, 33)) + assert_array_equal(test, np.arange(35, 40)) + + # expanding + splits = RollingWindowCV(2, slide=1.0, bias="train", expanding=True).split(X) + train, test = next(splits) + assert_array_equal(train, np.arange(0, 25)) + assert_array_equal(test, np.arange(25, 30)) + train, test = next(splits) + assert_array_equal(train, np.arange(0, 35)) + assert_array_equal(test, np.arange(35, 40)) + + def test_nested_cv(): # Test if nested cross validation works with different combinations of cv rng = np.random.RandomState(0)