@@ -139,7 +139,8 @@ def test_2d_y():
139
139
ShuffleSplit (), StratifiedShuffleSplit (test_size = .5 ),
140
140
GroupShuffleSplit (), LeaveOneGroupOut (),
141
141
LeavePGroupsOut (n_groups = 2 ), GroupKFold (n_splits = 3 ),
142
- TimeSeriesSplit (), PredefinedSplit (test_fold = groups )]
142
+ TimeSeriesSplit (2 ),
143
+ PredefinedSplit (test_fold = groups )]
143
144
for splitter in splitters :
144
145
list (splitter .split (X , y , groups ))
145
146
list (splitter .split (X , y_2d , groups ))
@@ -1382,6 +1383,7 @@ def test_group_kfold():
1382
1383
1383
1384
def test_time_series_cv ():
1384
1385
X = [[1 , 2 ], [3 , 4 ], [5 , 6 ], [7 , 8 ], [9 , 10 ], [11 , 12 ], [13 , 14 ]]
1386
+ groups = np .array ([1 , 1 , 2 , 2 , 3 , 4 , 5 ])
1385
1387
1386
1388
# Should fail if there are more folds than samples
1387
1389
assert_raises_regexp (ValueError , "Cannot have number of folds.*greater" ,
@@ -1411,12 +1413,37 @@ def test_time_series_cv():
1411
1413
assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1412
1414
assert_array_equal (test , [5 , 6 ])
1413
1415
1416
+ # ordering on toy datasets with group
1417
+ splits = tscv .split (X [:- 1 ], groups = groups [:- 1 ])
1418
+ train , test = next (splits )
1419
+ assert_array_equal (train , [0 , 1 , 2 , 3 ])
1420
+ assert_array_equal (test , [4 ])
1421
+
1422
+ train , test = next (splits )
1423
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1424
+ assert_array_equal (test , [5 ])
1425
+
1426
+ splits = TimeSeriesSplit (2 ).split (X )
1427
+
1428
+ train , test = next (splits )
1429
+ assert_array_equal (train , [0 , 1 , 2 ])
1430
+ assert_array_equal (test , [3 , 4 ])
1431
+
1432
+ train , test = next (splits )
1433
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1434
+ assert_array_equal (test , [5 , 6 ])
1435
+
1414
1436
# Check get_n_splits returns the correct number of splits
1415
1437
splits = TimeSeriesSplit (2 ).split (X )
1416
1438
n_splits_actual = len (list (splits ))
1417
1439
assert n_splits_actual == tscv .get_n_splits ()
1418
1440
assert n_splits_actual == 2
1419
1441
1442
+ splits = TimeSeriesSplit (2 ).split (X , groups = groups )
1443
+ n_splits_actual = len (list (splits ))
1444
+ assert n_splits_actual == tscv .get_n_splits ()
1445
+ assert n_splits_actual == 2
1446
+
1420
1447
1421
1448
def _check_time_series_max_train_size (splits , check_splits , max_train_size ):
1422
1449
for (train , test ), (check_train , check_test ) in zip (splits , check_splits ):
@@ -1428,21 +1455,39 @@ def _check_time_series_max_train_size(splits, check_splits, max_train_size):
1428
1455
1429
1456
def test_time_series_max_train_size ():
1430
1457
X = np .zeros ((6 , 1 ))
1458
+ groups = np .array ([3 , 4 , 5 , 1 , 2 , 2 ])
1431
1459
splits = TimeSeriesSplit (n_splits = 3 ).split (X )
1460
+ group_splits = TimeSeriesSplit (n_splits = 3 ).split (X , groups = groups )
1461
+
1432
1462
check_splits = TimeSeriesSplit (n_splits = 3 , max_train_size = 3 ).split (X )
1433
1463
_check_time_series_max_train_size (splits , check_splits , max_train_size = 3 )
1434
1464
1465
+ check_splits = TimeSeriesSplit (n_splits = 3 , max_train_size = 3 ) \
1466
+ .split (X , groups = groups )
1467
+ _check_time_series_max_train_size (group_splits ,
1468
+ check_splits , max_train_size = 3 )
1469
+
1435
1470
# Test for the case where the size of a fold is greater than max_train_size
1436
1471
check_splits = TimeSeriesSplit (n_splits = 3 , max_train_size = 2 ).split (X )
1437
1472
_check_time_series_max_train_size (splits , check_splits , max_train_size = 2 )
1438
1473
1474
+ check_splits = TimeSeriesSplit (n_splits = 2 , max_train_size = 2 ) \
1475
+ .split (X , groups = groups )
1476
+ _check_time_series_max_train_size (group_splits ,
1477
+ check_splits , max_train_size = 2 )
1478
+
1439
1479
# Test for the case where the size of each fold is less than max_train_size
1440
1480
check_splits = TimeSeriesSplit (n_splits = 3 , max_train_size = 5 ).split (X )
1441
1481
_check_time_series_max_train_size (splits , check_splits , max_train_size = 2 )
1442
1482
1483
+ check_splits = TimeSeriesSplit (n_splits = 3 , max_train_size = 5 ).split (X )
1484
+ _check_time_series_max_train_size (group_splits ,
1485
+ check_splits , max_train_size = 2 )
1486
+
1443
1487
1444
1488
def test_time_series_test_size ():
1445
1489
X = np .zeros ((10 , 1 ))
1490
+ groups = np .array ([6 , 7 , 1 , 1 , 1 , 2 , 2 , 3 , 4 , 5 ])
1446
1491
1447
1492
# Test alone
1448
1493
splits = TimeSeriesSplit (n_splits = 3 , test_size = 3 ).split (X )
@@ -1459,6 +1504,21 @@ def test_time_series_test_size():
1459
1504
assert_array_equal (train , [0 , 1 , 2 , 3 , 4 , 5 , 6 ])
1460
1505
assert_array_equal (test , [7 , 8 , 9 ])
1461
1506
1507
+ # Test alone with groups
1508
+ splits = TimeSeriesSplit (n_splits = 3 , test_size = 2 ).split (X , groups = groups )
1509
+
1510
+ train , test = next (splits )
1511
+ assert_array_equal (train , [0 ])
1512
+ assert_array_equal (test , [1 , 2 , 3 , 4 ])
1513
+
1514
+ train , test = next (splits )
1515
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1516
+ assert_array_equal (test , [5 , 6 , 7 ])
1517
+
1518
+ train , test = next (splits )
1519
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ])
1520
+ assert_array_equal (test , [8 , 9 ])
1521
+
1462
1522
# Test with max_train_size
1463
1523
splits = TimeSeriesSplit (n_splits = 2 , test_size = 2 ,
1464
1524
max_train_size = 4 ).split (X )
@@ -1471,14 +1531,31 @@ def test_time_series_test_size():
1471
1531
assert_array_equal (train , [4 , 5 , 6 , 7 ])
1472
1532
assert_array_equal (test , [8 , 9 ])
1473
1533
1534
+ # Test with max_train_size and groups
1535
+ splits = TimeSeriesSplit (n_splits = 2 , test_size = 2 ,
1536
+ max_train_size = 2 ).split (X , groups = groups )
1537
+
1538
+ train , test = next (splits )
1539
+ assert_array_equal (train , [1 , 2 , 3 , 4 ])
1540
+ assert_array_equal (test , [5 , 6 , 7 ])
1541
+
1542
+ train , test = next (splits )
1543
+ assert_array_equal (train , [5 , 6 , 7 ])
1544
+ assert_array_equal (test , [8 , 9 ])
1545
+
1474
1546
# Should fail with not enough data points for configuration
1475
1547
with pytest .raises (ValueError , match = "Too many splits.*with test_size" ):
1476
1548
splits = TimeSeriesSplit (n_splits = 5 , test_size = 2 ).split (X )
1477
1549
next (splits )
1550
+ with pytest .raises (ValueError , match = "Too many splits.*with test_size" ):
1551
+ splits = TimeSeriesSplit (n_splits = 5 , test_size = 2 ) \
1552
+ .split (X , groups = groups )
1553
+ next (splits )
1478
1554
1479
1555
1480
1556
def test_time_series_gap ():
1481
1557
X = np .zeros ((10 , 1 ))
1558
+ groups = np .array ([6 , 7 , 1 , 1 , 1 , 2 , 2 , 3 , 4 , 5 ])
1482
1559
1483
1560
# Test alone
1484
1561
splits = TimeSeriesSplit (n_splits = 2 , gap = 2 ).split (X )
@@ -1491,6 +1568,17 @@ def test_time_series_gap():
1491
1568
assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1492
1569
assert_array_equal (test , [7 , 8 , 9 ])
1493
1570
1571
+ # Test alone with groups
1572
+ splits = TimeSeriesSplit (n_splits = 2 , gap = 2 ).split (X , groups = groups )
1573
+
1574
+ train , test = next (splits )
1575
+ assert_array_equal (train , [0 ])
1576
+ assert_array_equal (test , [5 , 6 , 7 ])
1577
+
1578
+ train , test = next (splits )
1579
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1580
+ assert_array_equal (test , [8 , 9 ])
1581
+
1494
1582
# Test with max_train_size
1495
1583
splits = TimeSeriesSplit (n_splits = 3 , gap = 2 , max_train_size = 2 ).split (X )
1496
1584
@@ -1506,6 +1594,22 @@ def test_time_series_gap():
1506
1594
assert_array_equal (train , [4 , 5 ])
1507
1595
assert_array_equal (test , [8 , 9 ])
1508
1596
1597
+ # Test with max_train_size and groups
1598
+ splits = TimeSeriesSplit (n_splits = 3 , gap = 2 ,
1599
+ max_train_size = 2 ).split (X , groups = groups )
1600
+
1601
+ train , test = next (splits )
1602
+ assert_array_equal (train , [0 , 1 ])
1603
+ assert_array_equal (test , [7 ])
1604
+
1605
+ train , test = next (splits )
1606
+ assert_array_equal (train , [1 , 2 , 3 , 4 ])
1607
+ assert_array_equal (test , [8 ])
1608
+
1609
+ train , test = next (splits )
1610
+ assert_array_equal (train , [2 , 3 , 4 , 5 , 6 ])
1611
+ assert_array_equal (test , [9 ])
1612
+
1509
1613
# Test with test_size
1510
1614
splits = TimeSeriesSplit (n_splits = 2 , gap = 2 ,
1511
1615
max_train_size = 4 , test_size = 2 ).split (X )
@@ -1518,6 +1622,18 @@ def test_time_series_gap():
1518
1622
assert_array_equal (train , [2 , 3 , 4 , 5 ])
1519
1623
assert_array_equal (test , [8 , 9 ])
1520
1624
1625
+ # Test with test_size and groups
1626
+ splits = TimeSeriesSplit (n_splits = 2 , gap = 2 , max_train_size = 4 , test_size = 2 )\
1627
+ .split (X , groups = groups )
1628
+
1629
+ train , test = next (splits )
1630
+ assert_array_equal (train , [0 ])
1631
+ assert_array_equal (test , [5 , 6 , 7 ])
1632
+
1633
+ train , test = next (splits )
1634
+ assert_array_equal (train , [0 , 1 , 2 , 3 , 4 ])
1635
+ assert_array_equal (test , [8 , 9 ])
1636
+
1521
1637
# Test with additional test_size
1522
1638
splits = TimeSeriesSplit (n_splits = 2 , gap = 2 , test_size = 3 ).split (X )
1523
1639
@@ -1533,6 +1649,9 @@ def test_time_series_gap():
1533
1649
with pytest .raises (ValueError , match = "Too many splits.*and gap" ):
1534
1650
splits = TimeSeriesSplit (n_splits = 4 , gap = 2 ).split (X )
1535
1651
next (splits )
1652
+ with pytest .raises (ValueError , match = "Too many splits.*and gap" ):
1653
+ splits = TimeSeriesSplit (n_splits = 5 , gap = 2 ).split (X , groups = groups )
1654
+ next (splits )
1536
1655
1537
1656
1538
1657
def test_nested_cv ():
0 commit comments