feat(cv): enable group split in TimeSeriesCV · scikit-learn/scikit-learn@91e9d0a · GitHub
[go: up one dir, main page]

Skip to content

Commit 91e9d0a

Browse files
committed
feat(cv): enable group split in TimeSeriesCV
1 parent e9c6fca commit 91e9d0a

File tree

2 files changed

+161
-14
lines changed

2 files changed

+161
-14
lines changed

sklearn/model_selection/_split.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -761,17 +761,19 @@ class TimeSeriesSplit(_BaseKFold):
761761
762762
max_train_size : int, default=None
763763
Maximum size for a single training set.
764+
or maximum number of data groups when group is supplied
764765
765766
test_size : int, default=None
766767
Used to limit the size of the test set. Defaults to
767768
``n_samples // (n_splits + 1)``, which is the maximum allowed value
768-
with ``gap=0``.
769+
with ``gap=0``
770+
or number of data groups when group is supplied
769771
770772
.. versionadded:: 0.24
771773
772774
gap : int, default=0
773-
Number of samples to exclude from the end of each train set before
774-
the test set.
775+
Number of samples or groups to exclude from the end of each train set
776+
before the test set.
775777
776778
.. versionadded:: 0.24
777779
@@ -845,8 +847,9 @@ def split(self, X, y=None, groups=None):
845847
y : array-like of shape (n_samples,)
846848
Always ignored, exists for compatibility.
847849
848-
groups : array-like of shape (n_samples,)
849-
Always ignored, exists for compatibility.
850+
groups : array-like of shape (n_samples,), default=None
851+
Group labels for the samples used while splitting the dataset into
852+
train/test set.
850853
851854
Yields
852855
------
@@ -856,7 +859,18 @@ def split(self, X, y=None, groups=None):
856859
test : ndarray
857860
The testing set indices for that split.
858861
"""
859-
X, y, groups = indexable(X, y, groups)
862+
samples, y, groups = indexable(X, y, groups)
863+
864+
if groups is None:
865+
X = samples
866+
cv_type = "samples"
867+
else:
868+
_, count_index, count = np.unique(groups, return_counts=True,
869+
return_index=True)
870+
X = np.argsort(count_index)
871+
cum_count = np.concatenate(([0], np.cumsum(count[X])))
872+
cv_type = "groups"
873+
860874
n_samples = _num_samples(X)
861875
n_splits = self.n_splits
862876
n_folds = n_splits + 1
@@ -868,24 +882,38 @@ def split(self, X, y=None, groups=None):
868882
if n_folds > n_samples:
869883
raise ValueError(
870884
(f"Cannot have number of folds={n_folds} greater"
871-
f" than the number of samples={n_samples}."))
885+
f" than the number of {cv_type}={n_samples}."))
872886
if n_samples - gap - (test_size * n_splits) <= 0:
873887
raise ValueError(
874-
(f"Too many splits={n_splits} for number of samples"
888+
(f"Too many splits={n_splits} for number of {cv_type}"
875889
f"={n_samples} with test_size={test_size} and gap={gap}."))
876890

877-
indices = np.arange(n_samples)
891+
if groups is None:
892+
indices = np.arange(n_samples)
893+
else:
894+
indices = np.arange(_num_samples(samples))
878895
test_starts = range(n_samples - n_splits * test_size,
879896
n_samples, test_size)
880897

881898
for test_start in test_starts:
882899
train_end = test_start - gap
883900
if self.max_train_size and self.max_train_size < train_end:
884-
yield (indices[train_end - self.max_train_size:train_end],
885-
indices[test_start:test_start + test_size])
901+
if groups is None:
902+
yield (indices[train_end - self.max_train_size:train_end],
903+
indices[test_start:test_start + test_size])
904+
else:
905+
yield (indices[cum_count[train_end - self.max_train_size]:
906+
cum_count[train_end]],
907+
indices[cum_count[test_start]:
908+
cum_count[test_start + test_size]])
886909
else:
887-
yield (indices[:train_end],
888-
indices[test_start:test_start + test_size])
910+
if groups is None:
911+
yield (indices[:train_end],
912+
indices[test_start:test_start + test_size])
913+
else:
914+
yield (indices[:cum_count[train_end]],
915+
indices[cum_count[test_start]:
916+
cum_count[test_start + test_size]])
889917

890918

891919
class LeaveOneGroupOut(BaseCrossValidator):

sklearn/model_selection/tests/test_split.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def test_2d_y():
139139
ShuffleSplit(), StratifiedShuffleSplit(test_size=.5),
140140
GroupShuffleSplit(), LeaveOneGroupOut(),
141141
LeavePGroupsOut(n_groups=2), GroupKFold(n_splits=3),
142-
TimeSeriesSplit(), PredefinedSplit(test_fold=groups)]
142+
TimeSeriesSplit(2),
143+
PredefinedSplit(test_fold=groups)]
143144
for splitter in splitters:
144145
list(splitter.split(X, y, groups))
145146
list(splitter.split(X, y_2d, groups))
@@ -1382,6 +1383,7 @@ def test_group_kfold():
13821383

13831384
def test_time_series_cv():
13841385
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
1386+
groups = np.array([1, 1, 2, 2, 3, 4, 5])
13851387

13861388
# Should fail if there are more folds than samples
13871389
assert_raises_regexp(ValueError, "Cannot have number of folds.*greater",
@@ -1411,12 +1413,37 @@ def test_time_series_cv():
14111413
assert_array_equal(train, [0, 1, 2, 3, 4])
14121414
assert_array_equal(test, [5, 6])
14131415

1416+
# ordering on toy datasets with group
1417+
splits = tscv.split(X[:-1], groups=groups[:-1])
1418+
train, test = next(splits)
1419+
assert_array_equal(train, [0, 1, 2, 3])
1420+
assert_array_equal(test, [4])
1421+
1422+
train, test = next(splits)
1423+
assert_array_equal(train, [0, 1, 2, 3, 4])
1424+
assert_array_equal(test, [5])
1425+
1426+
splits = TimeSeriesSplit(2).split(X)
1427+
1428+
train, test = next(splits)
1429+
assert_array_equal(train, [0, 1, 2])
1430+
assert_array_equal(test, [3, 4])
1431+
1432+
train, test = next(splits)
1433+
assert_array_equal(train, [0, 1, 2, 3, 4])
1434+
assert_array_equal(test, [5, 6])
1435+
14141436
# Check get_n_splits returns the correct number of splits
14151437
splits = TimeSeriesSplit(2).split(X)
14161438
n_splits_actual = len(list(splits))
14171439
assert n_splits_actual == tscv.get_n_splits()
14181440
assert n_splits_actual == 2
14191441

1442+
splits = TimeSeriesSplit(2).split(X, groups=groups)
1443+
n_splits_actual = len(list(splits))
1444+
assert n_splits_actual == tscv.get_n_splits()
1445+
assert n_splits_actual == 2
1446+
14201447

14211448
def _check_time_series_max_train_size(splits, check_splits, max_train_size):
14221449
for (train, test), (check_train, check_test) in zip(splits, check_splits):
@@ -1428,21 +1455,39 @@ def _check_time_series_max_train_size(splits, check_splits, max_train_size):
14281455

14291456
def test_time_series_max_train_size():
14301457
X = np.zeros((6, 1))
1458+
groups = np.array([3, 4, 5, 1, 2, 2])
14311459
splits = TimeSeriesSplit(n_splits=3).split(X)
1460+
group_splits = TimeSeriesSplit(n_splits=3).split(X, groups=groups)
1461+
14321462
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3).split(X)
14331463
_check_time_series_max_train_size(splits, check_splits, max_train_size=3)
14341464

1465+
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3) \
1466+
.split(X, groups=groups)
1467+
_check_time_series_max_train_size(group_splits,
1468+
check_splits, max_train_size=3)
1469+
14351470
# Test for the case where the size of a fold is greater than max_train_size
14361471
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X)
14371472
_check_time_series_max_train_size(splits, check_splits, max_train_size=2)
14381473

1474+
check_splits = TimeSeriesSplit(n_splits=2, max_train_size=2) \
1475+
.split(X, groups=groups)
1476+
_check_time_series_max_train_size(group_splits,
1477+
check_splits, max_train_size=2)
1478+
14391479
# Test for the case where the size of each fold is less than max_train_size
14401480
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
14411481
_check_time_series_max_train_size(splits, check_splits, max_train_size=2)
14421482

1483+
check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
1484+
_check_time_series_max_train_size(group_splits,
1485+
check_splits, max_train_size=2)
1486+
14431487

14441488
def test_time_series_test_size():
14451489
X = np.zeros((10, 1))
1490+
groups = np.array([6, 7, 1, 1, 1, 2, 2, 3, 4, 5])
14461491

14471492
# Test alone
14481493
splits = TimeSeriesSplit(n_splits=3, test_size=3).split(X)
@@ -1459,6 +1504,21 @@ def test_time_series_test_size():
14591504
assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6])
14601505
assert_array_equal(test, [7, 8, 9])
14611506

1507+
# Test alone with groups
1508+
splits = TimeSeriesSplit(n_splits=3, test_size=2).split(X, groups=groups)
1509+
1510+
train, test = next(splits)
1511+
assert_array_equal(train, [0])
1512+
assert_array_equal(test, [1, 2, 3, 4])
1513+
1514+
train, test = next(splits)
1515+
assert_array_equal(train, [0, 1, 2, 3, 4])
1516+
assert_array_equal(test, [5, 6, 7])
1517+
1518+
train, test = next(splits)
1519+
assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6, 7])
1520+
assert_array_equal(test, [8, 9])
1521+
14621522
# Test with max_train_size
14631523
splits = TimeSeriesSplit(n_splits=2, test_size=2,
14641524
max_train_size=4).split(X)
@@ -1471,14 +1531,31 @@ def test_time_series_test_size():
14711531
assert_array_equal(train, [4, 5, 6, 7])
14721532
assert_array_equal(test, [8, 9])
14731533

1534+
# Test with max_train_size and groups
1535+
splits = TimeSeriesSplit(n_splits=2, test_size=2,
1536+
max_train_size=2).split(X, groups=groups)
1537+
1538+
train, test = next(splits)
1539+
assert_array_equal(train, [1, 2, 3, 4])
1540+
assert_array_equal(test, [5, 6, 7])
1541+
1542+
train, test = next(splits)
1543+
assert_array_equal(train, [5, 6, 7])
1544+
assert_array_equal(test, [8, 9])
1545+
14741546
# Should fail with not enough data points for configuration
14751547
with pytest.raises(ValueError, match="Too many splits.*with test_size"):
14761548
splits = TimeSeriesSplit(n_splits=5, test_size=2).split(X)
14771549
next(splits)
1550+
with pytest.raises(ValueError, match="Too many splits.*with test_size"):
1551+
splits = TimeSeriesSplit(n_splits=5, test_size=2) \
1552+
.split(X, groups=groups)
1553+
next(splits)
14781554

14791555

14801556
def test_time_series_gap():
14811557
X = np.zeros((10, 1))
1558+
groups = np.array([6, 7, 1, 1, 1, 2, 2, 3, 4, 5])
14821559

14831560
# Test alone
14841561
splits = TimeSeriesSplit(n_splits=2, gap=2).split(X)
@@ -1491,6 +1568,17 @@ def test_time_series_gap():
14911568
assert_array_equal(train, [0, 1, 2, 3, 4])
14921569
assert_array_equal(test, [7, 8, 9])
14931570

1571+
# Test alone with groups
1572+
splits = TimeSeriesSplit(n_splits=2, gap=2).split(X, groups=groups)
1573+
1574+
train, test = next(splits)
1575+
assert_array_equal(train, [0])
1576+
assert_array_equal(test, [5, 6, 7])
1577+
1578+
train, test = next(splits)
1579+
assert_array_equal(train, [0, 1, 2, 3, 4])
1580+
assert_array_equal(test, [8, 9])
1581+
14941582
# Test with max_train_size
14951583
splits = TimeSeriesSplit(n_splits=3, gap=2, max_train_size=2).split(X)
14961584

@@ -1506,6 +1594,22 @@ def test_time_series_gap():
15061594
assert_array_equal(train, [4, 5])
15071595
assert_array_equal(test, [8, 9])
15081596

1597+
# Test with max_train_size and groups
1598+
splits = TimeSeriesSplit(n_splits=3, gap=2,
1599+
max_train_size=2).split(X, groups=groups)
1600+
1601+
train, test = next(splits)
1602+
assert_array_equal(train, [0, 1])
1603+
assert_array_equal(test, [7])
1604+
1605+
train, test = next(splits)
1606+
assert_array_equal(train, [1, 2, 3, 4])
1607+
assert_array_equal(test, [8])
1608+
1609+
train, test = next(splits)
1610+
assert_array_equal(train, [2, 3, 4, 5, 6])
1611+
assert_array_equal(test, [9])
1612+
15091613
# Test with test_size
15101614
splits = TimeSeriesSplit(n_splits=2, gap=2,
15111615
max_train_size=4, test_size=2).split(X)
@@ -1518,6 +1622,18 @@ def test_time_series_gap():
15181622
assert_array_equal(train, [2, 3, 4, 5])
15191623
assert_array_equal(test, [8, 9])
15201624

1625+
# Test with test_size and groups
1626+
splits = TimeSeriesSplit(n_splits=2, gap=2, max_train_size=4, test_size=2)\
1627+
.split(X, groups=groups)
1628+
1629+
train, test = next(splits)
1630+
assert_array_equal(train, [0])
1631+
assert_array_equal(test, [5, 6, 7])
1632+
1633+
train, test = next(splits)
1634+
assert_array_equal(train, [0, 1, 2, 3, 4])
1635+
assert_array_equal(test, [8, 9])
1636+
15211637
# Test with additional test_size
15221638
splits = TimeSeriesSplit(n_splits=2, gap=2, test_size=3).split(X)
15231639

@@ -1533,6 +1649,9 @@ def test_time_series_gap():
15331649
with pytest.raises(ValueError, match="Too many splits.*and gap"):
15341650
splits = TimeSeriesSplit(n_splits=4, gap=2).split(X)
15351651
next(splits)
1652+
with pytest.raises(ValueError, match="Too many splits.*and gap"):
1653+
splits = TimeSeriesSplit(n_splits=5, gap=2).split(X, groups=groups)
1654+
next(splits)
15361655

15371656

15381657
def test_nested_cv():

0 commit comments

Comments
 (0)
0