|
1 | 1 | # Authors: The scikit-learn developers
|
2 | 2 | # SPDX-License-Identifier: BSD-3-Clause
|
3 | 3 |
|
4 |
| -from numbers import Integral, Real |
| 4 | +from numbers import Real |
5 | 5 |
|
6 | 6 | import numpy as np
|
7 | 7 |
|
@@ -92,14 +92,24 @@ class TargetEncoder(OneToOneFeatureMixin, _BaseEncoder):
|
92 | 92 | more weight on the global target mean.
|
93 | 93 | If `"auto"`, then `smooth` is set to an empirical Bayes estimate.
|
94 | 94 |
|
95 |
| - cv : int or cross-validation generator, default=5 |
96 |
| - Determines the number of folds in the :term:`cross fitting` strategy used in |
97 |
| - :meth:`fit_transform`. For classification targets, `StratifiedKFold` is used |
98 |
| - and for continuous targets, `KFold` is used. |
| 95 | + cv : int, cross-validation generator or an iterable, default=None |
| 96 | + Determines the cross-validation splitting strategy. |
| 97 | + Possible inputs for cv are: |
99 | 98 |
|
100 |
| - If an integer is provided, it is the number of folds. |
101 |
| - If a cross-validation generator is provided, it should be compatible with |
102 |
| - scikit-learn's cross-validation interface. |
| 99 | + - None, to use the default 5-fold cross validation, |
| 100 | + - integer, to specify the number of folds in a `(Stratified)KFold`, |
| 101 | + - :term:`CV splitter`, |
| 102 | + - An iterable yielding (train, test) splits as arrays of indices. |
| 103 | +
|
| 104 | + For integer/None inputs, if the target ``y`` is either binary or |
| 105 | + multiclass, |
| 106 | + :class:`~sklearn.model_selection.StratifiedKFold` is used. In all other |
| 107 | + cases, :class:`~sklearn.model_selection.KFold` is used. These splitters |
| 108 | + are instantiated with `shuffle=False` so the splits will be the same |
| 109 | + across calls. |
| 110 | +
|
| 111 | + Refer to the :ref:`User Guide <cross_validation>` for the various |
| 112 | + cross-validation strategies that can be used here. |
103 | 113 |
|
104 | 114 | shuffle : bool, default=True
|
105 | 115 | Whether to shuffle the data in :meth:`fit_transform` before splitting into
|
@@ -195,10 +205,7 @@ class TargetEncoder(OneToOneFeatureMixin, _BaseEncoder):
|
195 | 205 | "categories": [StrOptions({"auto"}), list],
|
196 | 206 | "target_type": [StrOptions({"auto", "continuous", "binary", "multiclass"})],
|
197 | 207 | "smooth": [StrOptions({"auto"}), Interval(Real, 0, None, closed="left")],
|
198 |
| - "cv": [ |
199 |
| - Interval(Integral, 2, None, closed="left"), |
200 |
| - "cv_object", |
201 |
| - ], |
| 208 | + "cv": ["cv_object"], |
202 | 209 | "shuffle": ["boolean"],
|
203 | 210 | "random_state": ["random_state"],
|
204 | 211 | }
|
|
0 commit comments