feat(cv): enable group split in TimeSeriesCV · scikit-learn/scikit-learn@9ce813d · GitHub
[go: up one dir, main page]

Skip to content

Commit 9ce813d

Browse files
committed
feat(cv): enable group split in TimeSeriesCV
1 parent b5e55f7 commit 9ce813d

File tree

2 files changed

+161
-14
lines changed

2 files changed

+161
-14
lines changed

sklearn/model_selection/_split.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -761,17 +761,19 @@ class TimeSeriesSplit(_BaseKFold):
761761
762762
max_train_size : int, default=None
763763
Maximum size for a single training set.
764+
When ``groups`` is supplied, this is the maximum number of groups instead.
764765
765766
test_size : int, default=None
766767
Used to limit the size of the test set. Defaults to
767768
``n_samples // (n_splits + 1)``, which is the maximum allowed value
768-
with ``gap=0``.
769+
with ``gap=0``.
770+
When ``groups`` is supplied, the size is counted in groups.
769771
770772
.. versionadded:: 0.24
771773
772774
gap : int, default=0
773-
Number of samples to exclude from the end of each train set before
774-
the test set.
775+
Number of samples or groups to exclude from the end of each train set
776+
before the test set.
775777
776778
.. versionadded:: 0.24
777779
@@ -845,8 +847,9 @@ def split(self, X, y=None, groups=None):
845847
y : array-like of shape (n_samples,)
846848
Always ignored, exists for compatibility.
847849
848-
groups : array-like of shape (n_samples,)
849-
Always ignored, exists for compatibility.
850+
groups : array-like of shape (n_samples,), default=None
851+
Group labels for the samples used while splitting the dataset into
852+
train/test set.
850853
851854
Yields
852855
------
@@ -856,7 +859,18 @@ def split(self, X, y=None, groups=None):
856859
test : ndarray
857860
The testing set indices for that split.
858861
"""
859-
X, y, groups = indexable(X, y, groups)
862+
samples, y, groups = indexable(X, y, groups)
863+
864+
if groups is None:
865+
X = samples
866+
cv_type = "samples"
867+
else:
868+
_, count_index, count = np.unique(groups, return_counts=True,
869+
return_index=True)
870+
X = np.argsort(count_index)
871+
cum_count = np.concatenate(([0], np.cumsum(count[X])))
872+
cv_type = "groups"
873+
860874
n_samples = _num_samples(X)
861875
n_splits = self.n_splits
862876
n_folds = n_splits + 1
@@ -868,24 +882,38 @@ def split(self, X, y=None, groups=None):
868882
if n_folds > n_samples:
869883
raise ValueError(
870884
(f"Cannot have number of folds={n_folds} greater"
871-
f" than the number of samples={n_samples}."))
885+
f" than the number of {cv_type}={n_samples}."))
872886
if n_samples - gap - (test_size * n_splits) <= 0:
873887
raise ValueError(
874-
(f"Too many splits={n_splits} for number of samples"
888+
(f"Too many splits={n_splits} for number of {cv_type}"
875889
f"={n_samples} with test_size={test_size} and gap={gap}."))
876890

877-
indices = np.arange(n_samples)
891+
if groups is None:
892+
indices = np.arange(n_samples)
893+
else:
894+
indices = np.arange(_num_samples(samples))
878895
test_starts = range(n_samples - n_splits * test_size,
879896
n_samples, test_size)
880897

881898
for test_start in test_starts:
882899
train_end = test_start - gap
883900
if self.max_train_size and self.max_train_size < train_end:
884-
yield (indices[train_end - self.max_train_size:train_end],
885-
indices[test_start:test_start + test_size])
901+
if groups is None:
902+
yield (indices[train_end - self.max_train_size:train_end],
903+
indices[test_start:test_start + test_size])
904+
else:
905+
yield (indices[cum_count[train_end - self.max_train_size]:
906+
cum_count[train_end]],
907+
indices[cum_count[test_start]:
908+
cum_count[test_start + test_size]])
886909
else:
887-
yield (indices[:train_end],
888-
indices[test_start:test_start + test_size])
910+
if groups is None:
911+
yield (indices[:train_end],
912+
indices[test_start:test_start + test_size])
913+
else:
914+
yield (indices[:cum_count[train_end]],
915+
indices[cum_count[test_start]:
916+
cum_count[test_start + test_size]])
889917

890918

891919
class LeaveOneGroupOut(BaseCrossValidator):

sklearn/model_selection/tests/test_split.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def test_2d_y():
138138
ShuffleSplit(), StratifiedShuffleSplit(test_size=.5),
139139
GroupShuffleSplit(), LeaveOneGroupOut(),
140140
LeavePGroupsOut(n_groups=2), GroupKFold(n_splits=3),
141-
TimeSeriesSplit(), PredefinedSplit(test_fold=groups)]
141+
TimeSeriesSplit(2),
142+
PredefinedSplit(test_fold=groups)]
142143
for splitter in splitters:
143144
list(splitter.split(X, y, groups))
144145
list(splitter.split(X, y_2d, groups))
@@ -1381,6 +1382,7 @@ def test_group_kfold():
13811382

13821383
def test_time_series_cv():
13831384
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
1385+
groups = np.array([1, 1, 2, 2, 3, 4, 5])
13841386

13851387
# Should fail if there are more folds than samples
13861388
assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
@@ -1410,12 +1412,37 @@ def test_time_series_cv():
14101412
assert_array_equal(train, [0, 1, 2, 3, 4])
14111413
assert_array_equal(test, [5, 6])
14121414

1415+
# ordering on toy datasets with group
1416+
splits = tscv.split(X[:-1], groups=groups[:-1])
1417+
train, test = next(splits)
1418+
assert_array_equal(train, [0, 1, 2, 3])
1419+
assert_array_equal(test, [4])
1420+
1421+
train, test = next(splits)
1422+
assert_array_equal(train, [0, 1, 2, 3, 4])
1423+
assert_array_equal(test, [5])
1424+
1425+
splits = TimeSeriesSplit(2).split(X)
1426+
1427+
train, test = next(splits)
1428+
assert_array_equal(train, [0, 1, 2])
1429+
assert_array_equal(test, [3, 4])
1430+
1431+
train, test = next(splits)
1432+
assert_array_equal(train, [0, 1, 2, 3, 4])
1433+
assert_array_equal(test, [5, 6])
1434+
14131435
# Check get_n_splits returns the correct number of splits
14141436
splits = TimeSeriesSplit(2).split(X)
14151437
n_splits_actual = len(list(splits))
14161438
assert n_splits_actual == tscv.get_n_splits()
14171439
assert n_splits_actual == 2
14181440

1441+
splits = TimeSeriesSplit(2).split(X, groups=groups)
1442+
n_splits_actual = len(list(splits))
1443+
assert n_splits_actual == tscv.get_n_splits()
1444+
assert n_splits_actual == 2
1445+
14191446

14201447
def _check_time_series_max_train_size(splits, check_splits, max_train_size):
14211448
for (train, test), (check_train, check_test) in zip(splits, check_splits):
@@ -1427,21 +1454,39 @@ def _check_time_series_max_train_size(splits, check_splits, max_train_size):
14271454

14281455
def test_time_series_max_train_size():
14291456
X = np.zeros((6, 1))
1457+
groups = np.array([3, 4, 5, 1, 2, 2])
14301458
splits = TimeSeriesSplit(n_splits=3).split(X)
1459+
group_splits = TimeSeriesSplit(n_splits=3).split(X, groups=groups)
1460+
14311461
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3).split(X)
14321462
_check_time_series_max_train_size(splits, check_splits, max_train_size=3)
14331463

1464+
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3) \
1465+
.split(X, groups=groups)
1466+
_check_time_series_max_train_size(group_splits,
1467+
check_splits, max_train_size=3)
1468+
14341469
# Test for the case where the size of a fold is greater than max_train_size
14351470
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X)
14361471
_check_time_series_max_train_size(splits, check_splits, max_train_size=2)
14371472

1473+
check_splits = TimeSeriesSplit(n_splits=2, max_train_size=2) \
1474+
.split(X, groups=groups)
1475+
_check_time_series_max_train_size(group_splits,
1476+
check_splits, max_train_size=2)
1477+
14381478
# Test for the case where the size of each fold is less than max_train_size
14391479
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
14401480
_check_time_series_max_train_size(splits, check_splits, max_train_size=2)
14411481

1482+
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
1483+
_check_time_series_max_train_size(group_splits,
1484+
check_splits, max_train_size=2)
1485+
14421486

14431487
def test_time_series_test_size():
14441488
X = np.zeros((10, 1))
1489+
groups = np.array([6, 7, 1, 1, 1, 2, 2, 3, 4, 5])
14451490

14461491
# Test alone
14471492
splits = TimeSeriesSplit(n_splits=3, test_size=3).split(X)
@@ -1458,6 +1503,21 @@ def test_time_series_test_size():
14581503
assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6])
14591504
assert_array_equal(test, [7, 8, 9])
14601505

1506+
# Test alone with groups
1507+
splits = TimeSeriesSplit(n_splits=3, test_size=2).split(X, groups=groups)
1508+
1509+
train, test = next(splits)
1510+
assert_array_equal(train, [0])
1511+
assert_array_equal(test, [1, 2, 3, 4])
1512+
1513+
train, test = next(splits)
1514+
assert_array_equal(train, [0, 1, 2, 3, 4])
1515+
assert_array_equal(test, [5, 6, 7])
1516+
1517+
train, test = next(splits)
1518+
assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6, 7])
1519+
assert_array_equal(test, [8, 9])
1520+
14611521
# Test with max_train_size
14621522
splits = TimeSeriesSplit(n_splits=2, test_size=2,
14631523
max_train_size=4).split(X)
@@ -1470,14 +1530,31 @@ def test_time_series_test_size():
14701530
assert_array_equal(train, [4, 5, 6, 7])
14711531
assert_array_equal(test, [8, 9])
14721532

1533+
# Test with max_train_size and groups
1534+
splits = TimeSeriesSplit(n_splits=2, test_size=2,
1535+
max_train_size=2).split(X, groups=groups)
1536+
1537+
train, test = next(splits)
1538+
assert_array_equal(train, [1, 2, 3, 4])
1539+
assert_array_equal(test, [5, 6, 7])
1540+
1541+
train, test = next(splits)
1542+
assert_array_equal(train, [5, 6, 7])
1543+
assert_array_equal(test, [8, 9])
1544+
14731545
# Should fail with not enough data points for configuration
14741546
with pytest.raises(ValueError, match="Too many splits.*with test_size"):
14751547
splits = TimeSeriesSplit(n_splits=5, test_size=2).split(X)
14761548
next(splits)
1549+
with pytest.raises(ValueError, match="Too many splits.*with test_size"):
1550+
splits = TimeSeriesSplit(n_splits=5, test_size=2) \
1551+
.split(X, groups=groups)
1552+
next(splits)
14771553

14781554

14791555
def test_time_series_gap():
14801556
X = np.zeros((10, 1))
1557+
groups = np.array([6, 7, 1, 1, 1, 2, 2, 3, 4, 5])
14811558

14821559
# Test alone
14831560
splits = TimeSeriesSplit(n_splits=2, gap=2).split(X)
@@ -1490,6 +1567,17 @@ def test_time_series_gap():
14901567
assert_array_equal(train, [0, 1, 2, 3, 4])
14911568
assert_array_equal(test, [7, 8, 9])
14921569

1570+
# Test alone with groups
1571+
splits = TimeSeriesSplit(n_splits=2, gap=2).split(X, groups=groups)
1572+
1573+
train, test = next(splits)
1574+
assert_array_equal(train, [0])
1575+
assert_array_equal(test, [5, 6, 7])
1576+
1577+
train, test = next(splits)
1578+
assert_array_equal(train, [0, 1, 2, 3, 4])
1579+
assert_array_equal(test, [8, 9])
1580+
14931581
# Test with max_train_size
14941582
splits = TimeSeriesSplit(n_splits=3, gap=2, max_train_size=2).split(X)
14951583

@@ -1505,6 +1593,22 @@ def test_time_series_gap():
15051593
assert_array_equal(train, [4, 5])
15061594
assert_array_equal(test, [8, 9])
15071595

1596+
# Test with max_train_size and groups
1597+
splits = TimeSeriesSplit(n_splits=3, gap=2,
1598+
max_train_size=2).split(X, groups=groups)
1599+
1600+
train, test = next(splits)
1601+
assert_array_equal(train, [0, 1])
1602+
assert_array_equal(test, [7])
1603+
1604+
train, test = next(splits)
1605+
assert_array_equal(train, [1, 2, 3, 4])
1606+
assert_array_equal(test, [8])
1607+
1608+
train, test = next(splits)
1609+
assert_array_equal(train, [2, 3, 4, 5, 6])
1610+
assert_array_equal(test, [9])
1611+
15081612
# Test with test_size
15091613
splits = TimeSeriesSplit(n_splits=2, gap=2,
15101614
max_train_size=4, test_size=2).split(X)
@@ -1517,6 +1621,18 @@ def test_time_series_gap():
15171621
assert_array_equal(train, [2, 3, 4, 5])
15181622
assert_array_equal(test, [8, 9])
15191623

1624+
# Test with test_size and groups
1625+
splits = TimeSeriesSplit(n_splits=2, gap=2, max_train_size=4, test_size=2)\
1626+
.split(X, groups=groups)
1627+
1628+
train, test = next(splits)
1629+
assert_array_equal(train, [0])
1630+
assert_array_equal(test, [5, 6, 7])
1631+
1632+
train, test = next(splits)
1633+
assert_array_equal(train, [0, 1, 2, 3, 4])
1634+
assert_array_equal(test, [8, 9])
1635+
15201636
# Test with additional test_size
15211637
splits = TimeSeriesSplit(n_splits=2, gap=2, test_size=3).split(X)
15221638

@@ -1532,6 +1648,9 @@ def test_time_series_gap():
15321648
with pytest.raises(ValueError, match="Too many splits.*and gap"):
15331649
splits = TimeSeriesSplit(n_splits=4, gap=2).split(X)
15341650
next(splits)
1651+
with pytest.raises(ValueError, match="Too many splits.*and gap"):
1652+
splits = TimeSeriesSplit(n_splits=5, gap=2).split(X, groups=groups)
1653+
next(splits)
15351654

15361655

15371656
def test_nested_cv():

0 commit comments

Comments
 (0)
0