8000 Add GroupTimeSeriesSplit with params - gap and test_size by soso-song · Pull Request #19996 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Add GroupTimeSeriesSplit with params - gap and test_size #19996

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions doc/modules/cross_validation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,76 @@ Here is a visualization of the cross-validation behavior.
:align: center
:scale: 75%

Group Time Series Split
^^^^^^^^^^^^^^^^^^^^^^^

:class:`GroupTimeSeriesSplit` combines :class:`TimeSeriesSplit` with the Group awareness
of `GroupKFold`. Like :class:`TimeSeriesSplit` this also returns first :math:`k` folds
as train set and the :math:`(k+1)` th fold as test set.
Successive training sets are supersets of those that come before them.
Also, it adds all surplus data to the first training partition, which
is always used to train the model.
This class can be used to cross-validate time series data samples
that are observed at fixed time intervals.

The same group will not appear in two different folds (the number of
distinct groups has to be at least equal to the number of folds).

The groups should be Continuous like below.
['a', 'a', 'a', 'a', 'a', 'b', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'd', 'd', 'd']

Non-continuous groups like below will give an error.
['a', 'a', 'a', 'a', 'a', 'b','b', 'b', 'b', 'b', 'b', 'a', 'c', 'c', 'c', 'b', 'd', 'd']

`GroupTimeSeriesSplit` is useful in cases where we have time series data for
say multiple days with multiple data points within a day.
During cross-validation we may not want the training days to be be used in testing.
Here the days can act as groups to keep the training and test splits separate.

Example of 3-split time series cross-validation on a dataset with
18 samples and 4 groups::

>>> import numpy as np
>>> from sklearn.model_selection import GroupTimeSeriesSplit
>>> groups = np.array(['a', 'a', 'a', 'a', 'a', 'a',
... 'b', 'b', 'b', 'b', 'b',
... 'c', 'c', 'c', 'c',
... 'd', 'd', 'd'])
>>> gtss = GroupTimeSeriesSplit(n_splits=3)
>>> for train_idx, test_idx in gtss.split(groups, groups=groups):
... print("TRAIN:", train_idx, "TEST:", test_idx)
... print("TRAIN GROUP:", groups[train_idx],
... "TEST GROUP:", groups[test_idx])
TRAIN: [0, 1, 2, 3, 4, 5] TEST: [6, 7, 8, 9, 10]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a']
TEST GROUP: ['b' 'b' 'b' 'b' 'b']
TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] TEST: [11, 12, 13, 14]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b']
TEST GROUP: ['c' 'c' 'c' 'c']
TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
TEST: [15, 16, 17]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b' 'c' 'c' 'c' 'c']
TEST GROUP: ['d' 'd' 'd']

Example of 2-split time series cross-validation on a dataset with
18 samples and 4 groups and 1 test_size and 3 max_train_size and 1 period gap::

>>> import numpy as np
>>> from sklearn.model_selection import GroupTimeSeriesSplit
>>> groups = np.array(['a', 'a', 'a', 'a', 'a', 'a',
... 'b', 'b', 'b', 'b', 'b',
... 'c', 'c', 'c', 'c',
... 'd', 'd', 'd'])
>>> gtss = GroupTimeSeriesSplit(n_splits=2, test_size=1, gap=1, max_train_size=3)
>>> for train_idx, test_idx in gtss.split(groups, groups=groups):
... print("TRAIN:", train_idx, "TEST:", test_idx)
... print("TRAIN GROUP:", groups[train_idx], "TEST GROUP:", groups[test_idx])
TRAIN: [0, 1, 2, 3, 4, 5] TEST: [11, 12, 13, 14]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a'] TEST GROUP: ['c' 'c' 'c' 'c']
TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] TEST: [15, 16, 17]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b']
TEST GROUP: ['d' 'd' 'd']

A note on shuffling
===================

Expand Down
121 changes: 121 additions & 0 deletions doc/modules/group_time_series_split.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@

.. _GroupTimeSeriesSplit:

=================================================
sklearn.model_selection.GroupTimeSeriesSplit
=================================================
.. code-block:: python

class sklearn.model_selection.GroupTimeSeriesSplit(n_splits=5, *, max_train_size=None, test_size=None, gap=0)

| *GroupTimeSeriesSplit* combines *TimeSeriesSplit* with the Group awareness of *GroupKFold*.
|
| Like *TimeSeriesSplit* this also returns first *k* folds as train set and the *(k+1)* th fold as test set.
|
| Since the Group applies on this class, the same group will not appear in two different
folds(the number of distinct groups has to be at least equal to the number of folds) which make sure the i.i.d. assumption will not be broken.

| All operations of this CV strategy are done at the group level.
| So all our parameters, not limited to splits, including test_size, gap, and max_train_size, all represent the constraints on the number of groups.


Parameters:
-----------
| **n_splits;int,default=5**
|
| Number of splits. Must be at least 2.
|
| **max_train_size:int, default=None**
|
| Maximum number of groups for a single training set.
|
| **test_size:int, default=None**
|
| Used to limit the number of groups in the test set. Defaults to ``n_samples // (n_splits + 1)``, which is the maximum allowed value with ``gap=0``.
|
| **gap:int, default=0**
|
| Number of groups in samples to exclude from the end of each train set before the test set.

Example 1:
---------
.. code-block:: python

>>> import numpy as np
>>> from sklearn.model_selection import GroupTimeSeriesSplit
>>> groups = np.array(['a', 'a', 'a', 'a', 'a', 'a',
... 'b', 'b', 'b', 'b', 'b',
... 'c', 'c', 'c', 'c',
... 'd', 'd', 'd'])
>>> gtss = GroupTimeSeriesSplit(n_splits=3)
>>> for train_idx, test_idx in gtss.split(groups, groups=groups):
... print("TRAIN:", train_idx, "TEST:", test_idx)
... print("TRAIN GROUP:", groups[train_idx],
... "TEST GROUP:", groups[test_idx])
TRAIN: [0, 1, 2, 3, 4, 5] TEST: [6, 7, 8, 9, 10]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a']
TEST GROUP: ['b' 'b' 'b' 'b' 'b']
TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] TEST: [11, 12, 13, 14]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b']
TEST GROUP: ['c' 'c' 'c' 'c']
TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
TEST: [15, 16, 17]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b' 'c' 'c' 'c' 'c']
TEST GROUP: ['d' 'd' 'd']

Example 2:
---------
.. code-block:: python

>>> import numpy as np
>>> from sklearn.model_selection import GroupTimeSeriesSplit
>>> groups = np.array(['a', 'a', 'a', 'a', 'a', 'a',\
'b', 'b', 'b', 'b', 'b',\
'c', 'c', 'c', 'c',\
'd', 'd', 'd'])
>>> gtss = GroupTimeSeriesSplit(n_splits=2, test_size=1, gap=1,\
max_train_size=3)
>>> for train_idx, test_idx in gtss.split(groups, groups=groups):
... print("TRAIN:", train_idx, "TEST:", test_idx)
... print("TRAIN GROUP:", groups[train_idx],\
"TEST GROUP:", groups[test_idx])
TRAIN: [0, 1, 2, 3, 4, 5] TEST: [11, 12, 13, 14]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a'] TEST GROUP: ['c' 'c' 'c' 'c']
TRAIN: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] TEST: [15, 16, 17]
TRAIN GROUP: ['a' 'a' 'a' 'a' 'a' 'a' 'b' 'b' 'b' 'b' 'b']
TEST GROUP: ['d' 'd' 'd']

Methods:
--------
| **get_n_splits([X, y, groups])**
|
| Returns the number of splitting iterations in the cross-validator
| *Parameters:*
| *X: object*
| Always ignored, exists for compatibility.
| *y: object*
| Always ignored, exists for compatibility.
| *groups: object*
| Always ignored, exists for compatibility.
| *Returns:*
| *n_splits: int*
| Returns the number of splitting iterations in the cross-validator.
|
| **split(X[groups, y])**
|
| Generate indices to split data into training and test set by group.
| *Parameters:*
| *X : array-like of shape (n_samples, n_features)*
| Training data, where n_samples is the number of samples
| and n_features is the number of features.
| *y : array-like of shape (n_samples,)*
| Always ignored, exists for compatibility.
| *groups : array-like of shape (n_samples,)*
| Group labels for the samples used while splitting the dataset into
| train/test set.
| *Yields:*
| *train : ndarray*
| The training set indices for that split.
| *test : ndarray*
| The testing set indices for that split.

10 changes: 6 additions & 4 deletions examples/model_selection/plot_cv_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
for comparison.
"""

from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit,
StratifiedKFold, GroupShuffleSplit,
GroupKFold, StratifiedShuffleSplit,
from sklearn.model_selection import (TimeSeriesSplit, GroupTimeSeriesSplit,
KFold, ShuffleSplit, StratifiedKFold,
GroupShuffleSplit, GroupKFold,
StratifiedShuffleSplit,
StratifiedGroupKFold)
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -151,7 +152,8 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
# Note how some use the group/class information while others do not.

cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold, StratifiedGroupKFold,
GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit]
GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit,
GroupTimeSeriesSplit]


for cv in cvs:
Expand Down
2 changes: 2 additions & 0 deletions sklearn/model_selection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ._split import BaseCrossValidator
from ._split import KFold
from ._split import GroupKFold
from ._split import GroupTimeSeriesSplit
from ._split import StratifiedKFold
from ._split import TimeSeriesSplit
from ._split import LeaveOneGroupOut
Expand Down Expand Up @@ -46,6 +47,7 @@
'KFold',
'GroupKFold',
'GroupShuffleSplit',
'GroupTimeSeriesSplit',
'LeaveOneGroupOut',
'LeaveOneOut',
'LeavePGroupsOut',
Expand Down
Loading
0