@@ -138,7 +138,8 @@ def test_2d_y():
138
138
ShuffleSplit (), StratifiedShuffleSplit (test_size = .5 ),
139
139
GroupShuffleSplit (), LeaveOneGroupOut (),
140
140
LeavePGroupsOut (n_groups = 2 ), GroupKFold (n_splits = 3 ),
141
- TimeSeriesSplit (), PredefinedSplit (test_fold = groups )]
141
+ TimeSeriesSplit (2 ),
142
+ PredefinedSplit (test_fold = groups )]
142
143
for splitter in splitters :
143
144
list (splitter .split (X , y , groups ))
144
145
list (splitter .split (X , y_2d , groups ))
@@ -1381,6 +1382,7 @@ def test_group_kfold():
1381
1382
1382
1383
def test_time_series_cv ():
1383
1384
X = [[1 , 2 ], [3 , 4 ], [5 , 6 ], [7 , 8 ], [9 , 10 ], [11 , 12 ], [13 , 14 ]]
1385
+ groups = np .array ([1 , 1 , 2 , 2 , 3 , 4 , 5 ])
1384
1386
1385
1387
# Should fail if there are more folds than samples
1386
1388
assert_raises_regexp (ValueError , "Cannot have number of folds.*greater" ,
@@ -1410,12 +1412,37 @@ def test_time_series_cv():
1410
1412
assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1411
1413
assert_array_equal (test , [5 , 6 ])
1412
1414
1415
+ # ordering on toy datasets with group
1416
+ splits = tscv .split (X [:- 1 ], groups = groups [:- 1 ])
1417
+ train , test = next (splits )
1418
+ assert_array_equal (train , [0 , 1 , 2 , 3 ])
1419
+ assert_array_equal (test , [4 ])
1420
+
1421
+ train , test = next (splits )
1422
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1423
+ assert_array_equal (test , [5 ])
1424
+
1425
+ splits = TimeSeriesSplit (2 ).split (X )
1426
+
1427
+ train , test = next (splits )
1428
+ assert_array_equal (train , [0 , 1 , 2 ])
1429
+ assert_array_equal (test , [3 , 4 ])
1430
+
1431
+ train , test = next (splits )
1432
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1433
+ assert_array_equal (test , [5 , 6 ])
1434
+
1413
1435
# Check get_n_splits returns the correct number of splits
1414
1436
splits = TimeSeriesSplit (2 ).split (X )
1415
1437
n_splits_actual = len (list (splits ))
1416
1438
assert n_splits_actual == tscv .get_n_splits ()
1417
1439
assert n_splits_actual == 2
1418
1440
1441
+ splits = TimeSeriesSplit (2 ).split (X , groups = groups )
1442
+ n_splits_actual = len (list (splits ))
1443
+ assert n_splits_actual == tscv .get_n_splits ()
1444
+ assert n_splits_actual == 2
1445
+
1419
1446
1420
1447
def _check_time_series_max_train_size (splits , check_splits , max_train_size ):
1421
1448
for (train , test ), (check_train , check_test ) in zip (splits , check_splits ):
@@ -1427,21 +1454,39 @@ def _check_time_series_max_train_size(splits, check_splits, max_train_size):
1427
1454
1428
1455
def test_time_series_max_train_size():
    """Check that ``max_train_size`` only truncates the training window.

    For several ``max_train_size`` values, the splits of an uncapped
    ``TimeSeriesSplit(n_splits=3)`` are compared with those of a capped
    splitter through ``_check_time_series_max_train_size``.  Every
    comparison is run twice: once without groups and once with ``groups``.
    """
    X = np.zeros((6, 1))
    groups = np.array([3, 4, 5, 1, 2, 2])

    def compare(max_train_size, **split_kwargs):
        # ``split`` hands back a one-shot iterator and the helper walks its
        # two arguments with ``zip``.  An iterator reused from an earlier
        # comparison is already exhausted, which turns the helper's loop
        # into a no-op, so both sides are rebuilt for every comparison.
        splits = TimeSeriesSplit(n_splits=3).split(X, **split_kwargs)
        check_splits = TimeSeriesSplit(
            n_splits=3, max_train_size=max_train_size
        ).split(X, **split_kwargs)
        # The limit given to the helper is the one the capped splitter was
        # built with, and both splitters share n_splits and groups so that
        # their folds line up pairwise.
        _check_time_series_max_train_size(
            splits, check_splits, max_train_size=max_train_size
        )

    # NOTE(review): the grouped comparisons rely on ``split`` accepting
    # ``groups`` and on the helper's checks holding for grouped folds --
    # neither implementation is visible from this test; confirm.
    for split_kwargs in ({}, {"groups": groups}):
        compare(3, **split_kwargs)

        # Test for the case where the size of a fold is greater than
        # max_train_size
        compare(2, **split_kwargs)

        # Test for the case where the size of each fold is less than
        # max_train_size
        compare(5, **split_kwargs)
1485
+
1442
1486
1443
1487
def test_time_series_test_size ():
1444
1488
X = np .zeros ((10 , 1 ))
1489
+ groups = np .array ([6 , 7 , 1 , 1 , 1 , 2 , 2 , 3 , 4 , 5 ])
1445
1490
1446
1491
# Test alone
1447
1492
splits = TimeSeriesSplit (n_splits = 3 , test_size = 3 ).split (X )
@@ -1458,6 +1503,21 @@ def test_time_series_test_size():
1458
1503
assert_array_equal (train , [0 , 1 , 2 , 3 , 4 , 5 , 6 ])
1459
1504
assert_array_equal (test , [7 , 8 , 9 ])
1460
1505
1506
+ # Test alone with groups
1507
+ splits = TimeSeriesSplit (n_splits = 3 , test_size = 2 ).split (X , groups = groups )
1508
+
1509
+ train , test = next (splits )
1510
+ assert_array_equal (train , [0 ])
1511
+ assert_array_equal (test , [1 , 2 , 3 , 4 ])
1512
+
1513
+ train , test = next (splits )
1514
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1515
+ assert_array_equal (test , [5 , 6 , 7 ])
1516
+
1517
+ train , test = next (splits )
1518
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ])
1519
+ assert_array_equal (test , [8 , 9 ])
1520
+
1461
1521
# Test with max_train_size
1462
1522
splits = TimeSeriesSplit (n_splits = 2 , test_size = 2 ,
1463
1523
max_train_size = 4 ).split (X )
@@ -1470,14 +1530,31 @@ def test_time_series_test_size():
1470
1530
assert_array_equal (train , [4 , 5 , 6 , 7 ])
1471
1531
assert_array_equal (test , [8 , 9 ])
1472
1532
1533
+ # Test with max_train_size and groups
1534
+ splits = TimeSeriesSplit (n_splits = 2 , test_size = 2 ,
1535
+ max_train_size = 2 ).split (X , groups = groups )
1536
+
1537
+ train , test = next (splits )
1538
+ assert_array_equal (train , [1 , 2 , 3 , 4 ])
1539
+ assert_array_equal (test , [5 , 6 , 7 ])
1540
+
1541
+ train , test = next (splits )
1542
+ assert_array_equal (train , [5 , 6 , 7 ])
1543
+ assert_array_equal (test , [8 , 9 ])
1544
+
1473
1545
# Should fail with not enough data points for configuration
1474
1546
with pytest .raises (ValueError , match = "Too many splits.*with test_size" ):
1475
1547
splits = TimeSeriesSplit (n_splits = 5 , test_size = 2 ).split (X )
1476
1548
next (splits )
1549
+ with pytest .raises (ValueError , match = "Too many splits.*with test_size" ):
1550
+ splits = TimeSeriesSplit (n_splits = 5 , test_size = 2 ) \
1551
+ .split (X , groups = groups )
1552
+ next (splits )
1477
1553
1478
1554
1479
1555
def test_time_series_gap ():
1480
1556
X = np .zeros ((10 , 1 ))
1557
+ groups = np .array ([6 , 7 , 1 , 1 , 1 , 2 , 2 , 3 , 4 , 5 ])
1481
1558
1482
1559
# Test alone
1483
1560
splits = TimeSeriesSplit (n_splits = 2 , gap = 2 ).split (X )
@@ -1490,6 +1567,17 @@ def test_time_series_gap():
1490
1567
assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1491
1568
assert_array_equal (test , [7 , 8 , 9 ])
1492
1569
1570
+ # Test alone with groups
1571
+ splits = TimeSeriesSplit (n_splits = 2 , gap = 2 ).split (X , groups = groups )
1572
+
1573
+ train , test = next (splits )
1574
+ assert_array_equal (train , [0 ])
1575
+ assert_array_equal (test , [5 , 6 , 7 ])
1576
+
1577
+ train , test = next (splits )
1578
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1579
+ assert_array_equal (test , [8 , 9 ])
1580
+
1493
1581
# Test with max_train_size
1494
1582
splits = TimeSeriesSplit (n_splits = 3 , gap = 2 , max_train_size = 2 ).split (X )
1495
1583
@@ -1505,6 +1593,22 @@ def test_time_series_gap():
1505
1593
assert_array_equal (train , [4 , 5 ])
1506
1594
assert_array_equal (test , [8 , 9 ])
1507
1595
1596
+ # Test with max_train_size and groups
1597
+ splits = TimeSeriesSplit (n_splits = 3 , gap = 2 ,
1598
+ max_train_size = 2 ).split (X , groups = groups )
1599
+
1600
+ train , test = next (splits )
1601
+ assert_array_equal (train , [0 , 1 ])
1602
+ assert_array_equal (test , [7 ])
1603
+
1604
+ train , test = next (splits )
1605
+ assert_array_equal (train , [1 , 2 , 3 , 4 ])
1606
+ assert_array_equal (test , [8 ])
1607
+
1608
+ train , test = next (splits )
1609
+ assert_array_equal (train , [2 , 3 , 4 , 5 , 6 ])
1610
+ assert_array_equal (test , [9 ])
1611
+
1508
1612
# Test with test_size
1509
1613
splits = TimeSeriesSplit (n_splits = 2 , gap = 2 ,
1510
1614
max_train_size = 4 , test_size = 2 ).split (X )
@@ -1517,6 +1621,18 @@ def test_time_series_gap():
1517
1621
assert_array_equal (train , [2 , 3 , 4 , 5 ])
1518
1622
assert_array_equal (test , [8 , 9 ])
1519
1623
1624
+ # Test with test_size and groups
1625
+ splits = TimeSeriesSplit (n_splits = 2 , gap = 2 , max_train_size = 4 , test_size = 2 )\
1626
+ .split (X , groups = groups )
1627
+
1628
+ train , test = next (splits )
1629
+ assert_array_equal (train , [0 ])
1630
+ assert_array_equal (test , [5 , 6 , 7 ])
1631
+
1632
+ train , test = next (splits )
1633
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1634
+ assert_array_equal (test , [8 , 9 ])
1635
+
1520
1636
# Test with additional test_size
1521
1637
splits = TimeSeriesSplit (n_splits = 2 , gap = 2 , test_size = 3 ).split (X )
1522
1638
@@ -1532,6 +1648,9 @@ def test_time_series_gap():
1532
1648
with pytest .raises (ValueError , match = "Too many splits.*and gap" ):
1533
1649
splits = TimeSeriesSplit (n_splits = 4 , gap = 2 ).split (X )
1534
1650
next (splits )
1651
+ with pytest .raises (ValueError , match = "Too many splits.*and gap" ):
1652
+ splits = TimeSeriesSplit (n_splits = 5 , gap = 2 ).split (X , groups = groups )
1653
+ next (splits )
1535
1654
1536
1655
1537
1656
def test_nested_cv ():
0 commit comments