Add homogeneous-time-series-cv · scikit-learn/scikit-learn@04b9e79 · GitHub
[go: up one dir, main page]

Skip to content

Commit 04b9e79

Browse files
committed
Add homogeneous-time-series-cv
1 parent 6b1d351 commit 04b9e79

File tree

2 files changed

+117
-0
lines changed

2 files changed

+117
-0
lines changed

sklearn/model_selection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ._split import KFold
33
from ._split import LabelKFold
44
from ._split import StratifiedKFold
5+
from ._split import HomogeneousTimeSeriesCV
56
from ._split import LeaveOneLabelOut
67
from ._split import LeaveOneOut
78
from ._split import LeavePLabelOut
@@ -27,6 +28,7 @@
2728

2829
__all__ = ('BaseCrossValidator',
2930
'GridSearchCV',
31+
'HomogeneousTimeSeriesCV',
3032
'KFold',
3133
'LabelKFold',
3234
'LabelShuffleSplit',

sklearn/model_selection/_split.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,121 @@ def split(self, X, y, labels=None):
637637
"""
638638
return super(StratifiedKFold, self).split(X, y, labels)
639639

640+
class HomogeneousTimeSeriesCV(_BaseKFold):
    """Homogeneous Time Series cross-validator

    Provides train/test indices to split time series data in train/test sets.

    This cross-validation object is a variation of KFold.
    In iteration k, it returns the first k folds as train set and the
    (k+1)-th fold as test set, so the training indices always precede the
    test indices.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_folds : int, default=3
        Number of folds. Must be at least 2.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import HomogeneousTimeSeriesCV
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([1, 2, 3, 4])
    >>> htscv = HomogeneousTimeSeriesCV(n_folds=4)
    >>> htscv.get_n_splits(X)
    3
    >>> print(htscv)  # doctest: +NORMALIZE_WHITESPACE
    HomogeneousTimeSeriesCV(n_folds=4)
    >>> for train_index, test_index in htscv.split(X):
    ...    print("TRAIN:", train_index, "TEST:", test_index)
    ...    X_train, X_test = X[train_index], X[test_index]
    ...    y_train, y_test = y[train_index], y[test_index]
    TRAIN: [0] TEST: [1]
    TRAIN: [0 1] TEST: [2]
    TRAIN: [0 1 2] TEST: [3]

    Notes
    -----
    The first ``n_samples % n_folds`` folds have size
    ``n_samples // n_folds + 1``, other folds have size
    ``n_samples // n_folds``, where ``n_samples`` is the number of samples.

    Unlike the other KFold-based cross-validators, this one performs
    ``n_folds - 1`` splitting iterations rather than ``n_folds``: the first
    fold is only ever used for training because no earlier data exists to
    train on.

    See also
    --------
    KFold : K-Folds cross-validator this class is a variation of.
    """
    def __init__(self, n_folds=3):
        # Shuffling would destroy the temporal order, so it is always off.
        super(HomogeneousTimeSeriesCV, self).__init__(n_folds,
                                                      shuffle=False,
                                                      random_state=None)

    def split(self, X, y=None, labels=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.

        y : array-like, shape (n_samples,)
            The target variable for supervised learning problems.
            Not used to compute the splits.

        labels : array-like, with shape (n_samples,), optional
            Passed through ``indexable`` together with ``X`` and ``y`` but
            otherwise not used to compute the splits.

        Returns
        -------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Raises
        ------
        ValueError
            If ``n_folds`` is greater than the number of samples.
        """
        X, y, labels = indexable(X, y, labels)
        n_samples = _num_samples(X)
        if self.n_folds > n_samples:
            raise ValueError(
                ("Cannot have number of folds n_folds={0} greater"
                 " than the number of samples: {1}.").format(self.n_folds,
                                                             n_samples))
        n_folds = self.n_folds
        indices = np.arange(n_samples)
        # Builtin ``int`` replaces the ``np.int`` alias, which was removed
        # from NumPy in 1.24.
        fold_sizes = np.full(n_folds, n_samples // n_folds, dtype=int)
        # Spread the remainder over the leading folds, one sample each.
        fold_sizes[:n_samples % n_folds] += 1
        current = 0
        for fold_size in fold_sizes:
            start, stop = current, current + fold_size
            # The first fold has no preceding data to train on, so it is
            # never emitted as a test set.
            if current != 0:
                yield indices[:start], indices[start:stop]
            current = stop

    def get_n_splits(self, X=None, y=None, labels=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        labels : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator,
            which is ``n_folds - 1``.
        """
        return self.n_folds - 1
753+
754+
640755
class LeaveOneLabelOut(BaseCrossValidator):
641756
"""Leave One Label Out cross-validator
642757

0 commit comments

Comments
 (0)
0