modify the TargetEncoder class to accept CV splitters · scikit-learn/scikit-learn@dcd4645 · GitHub
[go: up one dir, main page]

Skip to content

Commit dcd4645

Browse files
committed
modify the TargetEncoder class to accept CV splitters
1 parent 50f0fc7 commit dcd4645

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

sklearn/preprocessing/_target_encoder.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,11 +92,15 @@ class TargetEncoder(OneToOneFeatureMixin, _BaseEncoder):
9292
more weight on the global target mean.
9393
If `"auto"`, then `smooth` is set to an empirical Bayes estimate.
9494
95-
cv : int, default=5
95+
cv : int or cross-validation generator, default=5
9696
Determines the number of folds in the :term:`cross fitting` strategy used in
9797
:meth:`fit_transform`. For classification targets, `StratifiedKFold` is used
9898
and for continuous targets, `KFold` is used.
9999
100+
If an integer is provided, it is the number of folds.
101+
If a cross-validation generator is provided, it should be compatible with
102+
scikit-learn's cross-validation interface.
103+
100104
shuffle : bool, default=True
101105
Whether to shuffle the data in :meth:`fit_transform` before splitting into
102106
folds. Note that the samples within each split will not be shuffled.
@@ -191,7 +195,10 @@ class TargetEncoder(OneToOneFeatureMixin, _BaseEncoder):
191195
"categories": [StrOptions({"auto"}), list],
192196
"target_type": [StrOptions({"auto", "continuous", "binary", "multiclass"})],
193197
"smooth": [StrOptions({"auto"}), Interval(Real, 0, None, closed="left")],
194-
"cv": [Interval(Integral, 2, None, closed="left")],
198+
"cv": [
199+
Interval(Integral, 2, None, closed="left"),
200+
"cv_object",
201+
],
195202
"shuffle": ["boolean"],
196203
"random_state": ["random_state"],
197204
}

0 commit comments

Comments
 (0)
0