-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG] Feature: Additional TimeSeriesSplit
Functionality
#13204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
72bd6bc
5410798
f04c37d
3fed48f
266751f
1e06e43
166b9a9
dbb51b4
a0c16f6
04b2909
95a4e63
b4fa003
caa4398
a17d06a
9fdf59a
f74b5ae
bb52c26
a765033
7461e84
b9e1fe5
5d291eb
dc58df9
e1be8cf
02d17bc
0f21d45
d6797e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1440,6 +1440,100 @@ def test_time_series_max_train_size(): | |
_check_time_series_max_train_size(splits, check_splits, max_train_size=2) | ||
|
||
|
||
def test_time_series_test_size(): | ||
X = np.zeros((10, 1)) | ||
|
||
# Test alone | ||
splits = TimeSeriesSplit(n_splits=3, test_size=3).split(X) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0]) | ||
assert_array_equal(test, [1, 2, 3]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1, 2, 3]) | ||
assert_array_equal(test, [4, 5, 6]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6]) | ||
assert_array_equal(test, [7, 8, 9]) | ||
|
||
# Test with max_train_size | ||
splits = TimeSeriesSplit(n_splits=2, test_size=2, | ||
max_train_size=4).split(X) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [2, 3, 4, 5]) | ||
assert_array_equal(test, [6, 7]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [4, 5, 6, 7]) | ||
assert_array_equal(test, [8, 9]) | ||
|
||
# Should fail with not enough data points for configuration | ||
with pytest.raises(ValueError, match="Too many splits.*with test_size"): | ||
splits = TimeSeriesSplit(n_splits=5, test_size=2).split(X) | ||
next(splits) | ||
|
||
|
||
def test_time_series_gap(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This The two error cases can be in its own test as well. This is not a blocker and can be done in a followup PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would be happy to implement this suggestion in a followup PR |
||
X = np.zeros((10, 1)) | ||
|
||
# Test alone | ||
splits = TimeSeriesSplit(n_splits=2, gap=2).split(X) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1]) | ||
assert_array_equal(test, [4, 5, 6]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1, 2, 3, 4]) | ||
assert_array_equal(test, [7, 8, 9]) | ||
|
||
# Test with max_train_size | ||
splits = TimeSeriesSplit(n_splits=3, gap=2, max_train_size=2).split(X) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1]) | ||
assert_array_equal(test, [4, 5]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [2, 3]) | ||
assert_array_equal(test, [6, 7]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [4, 5]) | ||
assert_array_equal(test, [8, 9]) | ||
|
||
# Test with test_size | ||
splits = TimeSeriesSplit(n_splits=2, gap=2, | ||
max_train_size=4, test_size=2).split(X) | ||
kykosic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1, 2, 3]) | ||
assert_array_equal(test, [6, 7]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [2, 3, 4, 5]) | ||
assert_array_equal(test, [8, 9]) | ||
|
||
# Test with additional test_size | ||
splits = TimeSeriesSplit(n_splits=2, gap=2, test_size=3).split(X) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1]) | ||
assert_array_equal(test, [4, 5, 6]) | ||
|
||
train, test = next(splits) | ||
assert_array_equal(train, [0, 1, 2, 3, 4]) | ||
assert_array_equal(test, [7, 8, 9]) | ||
|
||
# Verify proper error is thrown | ||
with pytest.raises(ValueError, match="Too many splits.*and gap"): | ||
splits = TimeSeriesSplit(n_splits=4, gap=2).split(X) | ||
next(splits) | ||
|
||
|
||
def test_nested_cv(): | ||
# Test if nested cross validation works with different combinations of cv | ||
rng = np.random.RandomState(0) | ||
|
Uh oh!
There was an error while loading. Please reload this page.