DOC/TST Clarify group order in GroupKFold and LeaveOneGroupOut (#22582) · scikit-learn/scikit-learn@fbe2126 · GitHub
[go: up one dir, main page]

Skip to content

Commit fbe2126

Browse files
SamAdamDay and glemaitre
authored and committed
DOC/TST Clarify group order in GroupKFold and LeaveOneGroupOut (#22582)
1 parent e8199d4 commit fbe2126

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

sklearn/model_selection/_split.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,10 @@ class GroupKFold(_BaseKFold):
469469
.. versionchanged:: 0.22
470470
``n_splits`` default value changed from 3 to 5.
471471
472+
Notes
473+
-----
474+
Groups appear in an arbitrary order throughout the folds.
475+
472476
Examples
473477
--------
474478
>>> import numpy as np
@@ -1110,6 +1114,12 @@ class LeaveOneGroupOut(BaseCrossValidator):
11101114
11111115
Read more in the :ref:`User Guide <leave_one_group_out>`.
11121116
1117+
Notes
1118+
-----
1119+
Splits are ordered according to the index of the group left out. The first
1120+
split has testing set consisting of the group whose index in `groups` is
1121+
lowest, and so on.
1122+
11131123
Examples
11141124
--------
11151125
>>> import numpy as np
@@ -1137,7 +1147,6 @@ class LeaveOneGroupOut(BaseCrossValidator):
11371147
[[1 2]
11381148
[3 4]] [[5 6]
11391149
[7 8]] [1 2] [1 2]
1140-
11411150
"""
11421151

11431152
def _iter_test_masks(self, X, y, groups):

sklearn/model_selection/tests/test_split.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,26 @@ def test_leave_group_out_changing_groups():
10411041
assert 3 == LeaveOneGroupOut().get_n_splits(X, y=X, groups=groups)
10421042

10431043

1044+
def test_leave_group_out_order_dependence():
1045+
# Check that LeaveOneGroupOut orders the splits according to the index
1046+
# of the group left out.
1047+
groups = np.array([2, 2, 0, 0, 1, 1])
1048+
X = np.ones(len(groups))
1049+
1050+
splits = iter(LeaveOneGroupOut().split(X, groups=groups))
1051+
1052+
expected_indices = [
1053+
([0, 1, 4, 5], [2, 3]),
1054+
([0, 1, 2, 3], [4, 5]),
1055+
([2, 3, 4, 5], [0, 1]),
1056+
]
1057+
1058+
for expected_train, expected_test in expected_indices:
1059+
train, test = next(splits)
1060+
assert_array_equal(train, expected_train)
1061+
assert_array_equal(test, expected_test)
1062+
1063+
10441064
def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
10451065
X = y = groups = np.ones(0)
10461066
msg = re.escape("Found array with 0 sample(s)")

0 commit comments

Comments
 (0)
0