Add homogeneous-time-series-cv · scikit-learn/scikit-learn@75e7401 · GitHub
[go: up one dir, main page]

Skip to content

Commit 75e7401

Browse files
committed
Add homogeneous-time-series-cv
1 parent 6b1d351 commit 75e7401

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

sklearn/model_selection/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ._split import KFold
33
from ._split import LabelKFold
44
from ._split import StratifiedKFold
5+
from ._split import HomogeneousTimeSeriesCV
56
from ._split import LeaveOneLabelOut
67
from ._split import LeaveOneOut
78
from ._split import LeavePLabelOut
@@ -27,6 +28,7 @@
2728

2829
__all__ = ('BaseCrossValidator',
2930
'GridSearchCV',
31+
'HomogeneousTimeSeriesCV',
3032
'KFold',
3133
'LabelKFold',
3234
'LabelShuffleSplit',

sklearn/model_selection/_split.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,92 @@ def split(self, X, y, labels=None):
637637
"""
638638
return super(StratifiedKFold, self).split(X, y, labels)
639639

640+
class HomogeneousTimeSeriesCV(_BaseKFold):
    """Homogeneous Time Series cross-validator

    Provides train/test indices to split time series data in train/test sets.

    This cross-validation object is a variation of KFold.
    In iteration k, it returns the first k folds as the train set and the
    (k+1)-th fold as the test set, so every training index precedes every
    test index and the samples are never shuffled.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_folds : int, default=3
        Number of folds. Must be at least 2.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import HomogeneousTimeSeriesCV
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([1, 2, 3, 4])
    >>> htscv = HomogeneousTimeSeriesCV(n_folds=4)
    >>> htscv.get_n_splits(X)
    3
    >>> print(htscv)  # doctest: +NORMALIZE_WHITESPACE
    HomogeneousTimeSeriesCV(n_folds=4)
    >>> for train_index, test_index in htscv.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    TRAIN: [0] TEST: [1]
    TRAIN: [0 1] TEST: [2]
    TRAIN: [0 1 2] TEST: [3]

    Notes
    -----
    The first ``n_samples % n_folds`` folds have size
    ``n_samples // n_folds + 1``, other folds have size
    ``n_samples // n_folds``, where ``n_samples`` is the number of samples.

    The number of splitting iterations in this cross-validator is
    ``n_folds - 1``, which differs from the other KFold based
    cross-validators: the first fold is only ever used for training.

    See also
    --------
    KFold : The cross-validator this class is a variation of.
    """
    def __init__(self, n_folds=3):
        # Shuffling would break the temporal ordering, so it is always off.
        super(HomogeneousTimeSeriesCV, self).__init__(n_folds,
                                                      shuffle=False,
                                                      random_state=None)

    def split(self, X, y=None, labels=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like
            Training data; only its number of samples is used.

        y : object
            Always ignored, exists for compatibility.

        labels : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split (all indices that
            precede the test fold).

        test : ndarray
            The testing set indices for that split.

        Raises
        ------
        ValueError
            If ``n_folds`` is greater than the number of samples, since
            some folds would then be empty.
        """
        n_samples = _num_samples(X)
        n_folds = self.n_folds
        if n_folds > n_samples:
            raise ValueError(
                "Cannot have number of folds n_folds={0} greater than the "
                "number of samples: {1}.".format(n_folds, n_samples))
        indices = np.arange(n_samples)
        # Use the builtin ``int``: the ``np.int`` alias was removed from NumPy.
        fold_sizes = (n_samples // n_folds) * np.ones(n_folds, dtype=int)
        # Spread the remainder over the leading folds, one extra sample each.
        fold_sizes[:n_samples % n_folds] += 1
        # boundaries[i] is the exclusive end index of fold i.
        boundaries = np.cumsum(fold_sizes)
        # The first fold is never a test fold: it has no earlier data to
        # train on, hence n_folds - 1 splits in total.
        for start, stop in zip(boundaries[:-1], boundaries[1:]):
            yield indices[:start], indices[start:stop]

    def get_n_splits(self, X=None, y=None, labels=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        labels : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator,
            which is ``n_folds - 1`` for this class.
        """
        return self.n_folds - 1
724+
725+
640726
class LeaveOneLabelOut(BaseCrossValidator):
641727
"""Leave One Label Out cross-validator
642728

0 commit comments

Comments
 (0)
0