@@ -637,6 +637,121 @@ def split(self, X, y, labels=None):
637
637
"""
638
638
return super (StratifiedKFold , self ).split (X , y , labels )
639
639
640
+ class HomogeneousTimeSeriesCV (_BaseKFold ):
641
+ """Homogeneous Time Series cross-validator
642
+
643
+ Provides train/test indices to split time series data in train/test sets.
644
+
645
+ This cross-validation object is a variation of KFold.
646
+ In iteration k, it returns first k folds as train set and k+1 fold as
647
+ test set.
648
+
649
+ Read more in the :ref:`User Guide <cross_validation>`.
650
+
651
+ Parameters
652
+ ----------
653
+ n_folds : int, default=3
654
+ Number of folds. Must be at least 2.
655
+
656
+ Examples
657
+ --------
658
+ >>> from sklearn.model_selection import HomogeneousTimeSeriesCV
659
+ >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
660
+ >>> y = np.array([1, 2, 3, 4])
661
+ >>> htscv = HomogeneousTimeSeriesCV(n_folds=4)
662
+ >>> htscv.get_n_splits(X)
663
+ 3
664
+ >>> print(htscv) # doctest: +NORMALIZE_WHITESPACE
665
+ KFold(n_folds=2, random_state=None, shuffle=False)
666
+ >>> for train_index, test_index in htscv.split(X):
667
+ ... print("TRAIN:", train_index, "TEST:", test_index)
668
+ ... X_train, X_test = X[train_index], X[test_index]
669
+ ... y_train, y_test = y[train_index], y[test_index]
670
+ TRAIN: [0] TEST: [1]
671
+ TRAIN: [0 1] TEST: [2]
672
+ TRAIN: [1 2 3] TEST: [3]
673
+
674
+ Notes
675
+ -----
676
+ The first ``n_samples % n_folds`` folds have size
677
+ ``n_samples // n_folds + 1``, other folds have size
678
+ ``n_samples // n_folds``, where ``n_samples`` is the number of samples.
679
+
680
+ Number of splitting iterations in this cross-validator, n_folds-1,
681
+ is not equal to other KFold based cross-validators'.
682
+
683
+ See also
684
+ --------
685
+ """
686
+ def __init__ (self , n_folds = 3 ):
687
+ super (HomogeneousTimeSeriesCV , self ).__init__ (n_folds ,
688
+ shuffle = False ,
689
+ random_state = None )
690
+
691
+ def split (self , X , y = None , labels = None ):
692
+ """Generate indices to split data into training and test set.
693
+
694
+ Parameters
695
+ ----------
696
+ X : array-like, shape (n_samples, n_features)
697
+ Training data, where n_samples is the number of samples
698
+ and n_features is the number of features.
699
+
700
+ y : array-like, shape (n_samples,)
701
+ The target variable for supervised learning problems.
702
+
703
+ labels : array-like, with shape (n_samples,), optional
704
+ Group labels for the samples used while splitting the dataset into
705
+ train/test set.
706
+
707
+ Returns
708
+ -------
709
+ train : ndarray
710
+ The training set indices for that split.
711
+
712
+ test : ndarray
713
+ The testing set indices for that split.
714
+ """
715
+ X , y , labels = indexable (X , y , labels )
716
+ n_samples = _num_samples (X )
717
+ if self .n_folds > n_samples :
718
+ raise ValueError (
719
+ ("Cannot have number of folds n_folds={0} greater"
720
+ " than the number of samples: {1}." ).format (self .n_folds ,
721
+ n_samples ))
722
+ n_folds = self .n_folds
723
+ indices = np .arange (n_samples )
724
+ fold_sizes = (n_samples // n_folds ) * np .ones (n_folds , dtype = np .int )
725
+ fold_sizes [:n_samples % n_folds ] += 1
726
+ current = 0
727
+ for fold_size in fold_sizes :
728
+ start , stop = current , current + fold_size
729
+ if current != 0 :
730
+ yield indices [:start ], indices [start :stop ]
731
+ current = stop
732
+
733
+ def get_n_splits (self , X = None , y = None , labels = None ):
734
+ """Returns the number of splitting iterations in the cross-validator
735
+
736
+ Parameters
737
+ ----------
738
+ X : object
739
+ Always ignored, exists for compatibility.
740
+
741
+ y : object
742
+ Always ignored, exists for compatibility.
743
+
744
+ labels : object
745
+ Always ignored, exists for compatibility.
746
+
747
+ Returns
748
+ -------
749
+ n_splits : int
750
+ Returns the number of splitting iterations in the cross-validator.
751
+ """
752
+ return self .n_folds - 1
753
+
754
+
640
755
class LeaveOneLabelOut (BaseCrossValidator ):
641
756
"""Leave One Label Out cross-validator
642
757
0 commit comments