8000 FIX Improve error message when RepeatedStratifiedKFold.split is called without a y argument by Anurag-Varma · Pull Request #29402 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
< 8000 div class="d-flex flex-column mt-2">
3 changes: 3 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ Changelog
estimator without re-fitting it.
:pr:`29067` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Fix| Improve error message when :func:`model_selection.RepeatedStratifiedKFold.split` is called without a `y` argument
:pr:`29402` by :user:`Anurag Varma <Anurag-Varma>`.

:mod:`sklearn.neighbors`
........................

Expand Down
37 changes: 37 additions & 0 deletions sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,43 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
n_splits=n_splits,
)

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data, where `n_samples` is the number of samples
and `n_features` is the number of features.

Note that providing ``y`` is sufficient to generate the splits and
hence ``np.zeros(n_samples)`` may be used as a placeholder for
``X`` instead of actual training data.

y : array-like of shape (n_samples,)
The target variable for supervised learning problems.
Stratification is done based on the y labels.

groups : object
Always ignored, exists for compatibility.

Yields
------
train : ndarray
The training set indices for that split.

test : ndarray
The testing set indices for that split.

Notes
-----
Randomized CV splitters may return different results for each call of
split. You can make the results identical by setting `random_state`
to an integer.
"""
y = check_array(y, input_name="y", ensure_2d=False, dtype=None)
return super().split(X, y, groups=groups)


class BaseShuffleSplit(_MetadataRequester, metaclass=ABCMeta):
"""Base class for *ShuffleSplit.
Expand Down
15 changes: 15 additions & 0 deletions sklearn/model_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@

ALL_SPLITTERS = NO_GROUP_SPLITTERS + GROUP_SPLITTERS # type: ignore

SPLITTERS_REQUIRING_TARGET = [
StratifiedKFold(),
StratifiedShuffleSplit(),
RepeatedStratifiedKFold(),
]

X = np.ones(10)
y = np.arange(10) // 2
test_groups = (
Expand Down Expand Up @@ -2054,3 +2060,12 @@ def test_no_group_splitters_warns_with_groups(cv):

with pytest.warns(UserWarning, match=msg):
cv.split(X, y, groups=groups)


@pytest.mark.parametrize(
"cv", SPLITTERS_REQUIRING_TARGET, ids=[str(cv) for cv in SPLITTERS_REQUIRING_TARGET]
)
def test_stratified_splitter_without_y(cv):
msg = "missing 1 required positional argument: 'y'"
with pytest.raises(TypeError, match=msg):
cv.split(X)
0