FEA SLEP006: Metadata routing for `validation_curve` (#29329) · scikit-learn/scikit-learn@e1cf244 · GitHub
[go: up one dir, main page]

Skip to content

Commit e1cf244

Browse files
FEA SLEP006: Metadata routing for validation_curve (#29329)
Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 64ab789 commit e1cf244

File tree

4 files changed

+143
-27
lines changed

4 files changed

+143
-27
lines changed

doc/metadata_routing.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ Meta-estimators and functions supporting metadata routing:
305305
- :func:`sklearn.model_selection.cross_val_score`
306306
- :func:`sklearn.model_selection.cross_val_predict`
307307
- :class:`sklearn.model_selection.learning_curve`
308+
- :class:`sklearn.model_selection.validation_curve`
308309
- :class:`sklearn.multiclass.OneVsOneClassifier`
309310
- :class:`sklearn.multiclass.OneVsRestClassifier`
310311
- :class:`sklearn.multiclass.OutputCodeClassifier`
@@ -323,5 +324,4 @@ Meta-estimators and tools not supporting metadata routing yet:
323324
- :class:`sklearn.feature_selection.RFECV`
324325
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
325326
- :class:`sklearn.model_selection.permutation_test_score`
326-
- :class:`sklearn.model_selection.validation_curve`
327327
- :class:`sklearn.semi_supervised.SelfTrainingClassifier`

doc/whats_new/v1.6.rst

+4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ more details.
7777
params to the underlying regressor.
7878
:pr:`29136` by :user:`Omar Salman <OmarManzoor>`.
7979

80+
- |Feature| :func:`model_selection.validation_curve` now supports metadata routing for
81+
the `fit` method of its estimator and for its underlying CV splitter and scorer.
82+
:pr:`29329` by :user:`Stefanie Senger <StefanieSenger>`.
83+
8084
Dropping official support for PyPy
8185
----------------------------------
8286

sklearn/model_selection/_validation.py

+77-6
Original file line numberDiff line numberDiff line change
@@ -1855,7 +1855,7 @@ def learning_curve(
18551855
Parameters to pass to the fit method of the estimator.
18561856
18571857
.. deprecated:: 1.6
1858-
This parameter is deprecated and will be removed in version 1.6. Use
1858+
This parameter is deprecated and will be removed in version 1.8. Use
18591859
``params`` instead.
18601860
18611861
params : dict, default=None
@@ -2221,6 +2221,7 @@ def _incremental_fit_estimator(
22212221
"verbose": ["verbose"],
22222222
"error_score": [StrOptions({"raise"}), Real],
22232223
"fit_params": [dict, None],
2224+
"params": [dict, None],
22242225
},
22252226
prefer_skip_nested_validation=False, # estimator is not validated yet
22262227
)
@@ -2239,6 +2240,7 @@ def validation_curve(
22392240
verbose=0,
22402241
error_score=np.nan,
22412242
fit_params=None,
2243+
params=None,
22422244
):
22432245
"""Validation curve.
22442246
@@ -2277,6 +2279,13 @@ def validation_curve(
22772279
train/test set. Only used in conjunction with a "Group" :term:`cv`
22782280
instance (e.g., :class:`GroupKFold`).
22792281
2282+
.. versionchanged:: 1.6
2283+
``groups`` can only be passed if metadata routing is not enabled
2284+
via ``sklearn.set_config(enable_metadata_routing=True)``. When routing
2285+
is enabled, pass ``groups`` alongside other metadata via the ``params``
2286+
argument instead. E.g.:
2287+
``validation_curve(..., params={'groups': groups})``.
2288+
22802289
cv : int, cross-validation generator or an iterable, default=None
22812290
Determines the cross-validation splitting strategy.
22822291
Possible inputs for cv are:
@@ -2327,7 +2336,22 @@ def validation_curve(
23272336
fit_params : dict, default=None
23282337
Parameters to pass to the fit method of the estimator.
23292338
2330-
.. versionadded:: 0.24
2339+
.. deprecated:: 1.6
2340+
This parameter is deprecated and will be removed in version 1.8. Use
2341+
``params`` instead.
2342+
2343+
params : dict, default=None
2344+
Parameters to pass to the estimator, scorer and cross-validation object.
2345+
2346+
- If `enable_metadata_routing=False` (default):
2347+
Parameters directly passed to the `fit` method of the estimator.
2348+
2349+
- If `enable_metadata_routing=True`:
2350+
Parameters safely routed to the `fit` method of the estimator, to the
2351+
scorer and to the cross-validation object. See :ref:`Metadata Routing User
2352+
Guide <metadata_routing>` for more details.
2353+
2354+
.. versionadded:: 1.6
23312355
23322356
Returns
23332357
-------
@@ -2358,11 +2382,59 @@ def validation_curve(
23582382
>>> print(f"The average test accuracy is {test_scores.mean():.2f}")
23592383
The average test accuracy is 0.81
23602384
"""
2385+
params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")
23612386
X, y, groups = indexable(X, y, groups)
23622387

23632388
cv = check_cv(cv, y, classifier=is_classifier(estimator))
23642389
scorer = check_scoring(estimator, scoring=scoring)
23652390

2391+
if _routing_enabled():
2392+
router = (
2393+
MetadataRouter(owner="validation_curve")
2394+
.add(
2395+
estimator=estimator,
2396+
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
2397+
)
2398+
.add(
2399+
splitter=cv,
2400+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
2401+
)
2402+
.add(
2403+
scorer=scorer,
2404+
method_mapping=MethodMapping().add(caller="fit", callee="score"),
2405+
)
2406+
)
2407+
2408+
try:
2409+
routed_params = process_routing(router, "fit", **params)
2410+
except UnsetMetadataPassedError as e:
2411+
# The default exception would mention `fit` since in the above
2412+
# `process_routing` code, we pass `fit` as the caller. However,
2413+
# the user is not calling `fit` directly, so we change the message
2414+
# to make it more suitable for this case.
2415+
unrequested_params = sorted(e.unrequested_params)
2416+
raise UnsetMetadataPassedError(
2417+
message=(
2418+
f"{unrequested_params} are passed to `validation_curve` but are not"
2419+
" explicitly set as requested or not requested for"
2420+
f" validation_curve's estimator: {estimator.__class__.__name__}."
2421+
" Call `.set_fit_request({{metadata}}=True)` on the estimator for"
2422+
f" each metadata in {unrequested_params} that you"
2423+
" want to use and `metadata=False` for not using it. See the"
2424+
" Metadata Routing User guide"
2425+
" <https://scikit-learn.org/stable/metadata_routing.html> for more"
2426+
" information."
2427+
),
2428+
unrequested_params=e.unrequested_params,
2429+
routed_params=e.routed_params,
2430+
)
2431+
2432+
else:
2433+
routed_params = Bunch()
2434+
routed_params.estimator = Bunch(fit=params)
2435+
routed_params.splitter = Bunch(split={"groups": groups})
2436+
routed_params.scorer = Bunch(score={})
2437+
23662438
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, verbose=verbose)
23672439
results = parallel(
23682440
delayed(_fit_and_score)(
@@ -2374,14 +2446,13 @@ def validation_curve(
23742446
test=test,
23752447
verbose=verbose,
23762448
parameters={param_name: v},
2377-
fit_params=fit_params,
2378-
# TODO(SLEP6): support score params here
2379-
score_params=None,
2449+
fit_params=routed_params.estimator.fit,
2450+
score_params=routed_params.scorer.score,
23802451
return_train_score=True,
23812452
error_score=error_score,
23822453
)
23832454
# NOTE do not change order of iteration to allow one time cv splitters
2384-
for train, test in cv.split(X, y, groups)
2455+
for train, test in cv.split(X, y, **routed_params.splitter.split)
23852456
for v in param_range
23862457
)
23872458
n_params = len(param_range)

sklearn/model_selection/tests/test_validation.py

+61-20
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,7 @@ def test_validation_curve_cv_splits_consistency():
16971697
assert_array_almost_equal(np.array(scores3), np.array(scores1))
16981698

16991699

1700-
def test_validation_curve_fit_params():
1700+
def test_validation_curve_params():
17011701
X = np.arange(100).reshape(10, 10)
17021702
y = np.array([0] * 5 + [1] * 5)
17031703
clf = CheckingClassifier(expected_sample_weight=True)
@@ -1722,7 +1722,7 @@ def test_validation_curve_fit_params():
17221722
param_name="foo_param",
17231723
param_range=[1, 2, 3],
17241724
error_score="raise",
1725-
fit_params={"sample_weight": np.ones(1)},
1725+
params={"sample_weight": np.ones(1)},
17261726
)
17271727
validation_curve(
17281728
clf,
@@ -1731,7 +1731,7 @@ def test_validation_curve_fit_params():
17311731
param_name="foo_param",
17321732
param_range=[1, 2, 3],
17331733
error_score="raise",
1734-
fit_params={"sample_weight": np.ones(10)},
1734+
params={"sample_weight": np.ones(10)},
17351735
)
17361736

17371737

@@ -2482,29 +2482,54 @@ def test_cross_validate_return_indices(global_random_seed):
24822482
assert_array_equal(test_indices[split_idx], expected_test_idx)
24832483

24842484

2485-
# Tests for metadata routing in cross_val* and learning_curve
2486-
# ===========================================================
2485+
# Tests for metadata routing in cross_val* and in *curve
2486+
# ======================================================
24872487

24882488

24892489
# TODO(1.6): remove `cross_validate` and `cross_val_predict` from this test in 1.6 and
2490-
# `learning_curve` in 1.8
2491-
@pytest.mark.parametrize("func", [cross_validate, cross_val_predict, learning_curve])
2492-
def test_fit_param_deprecation(func):
2490+
# `learning_curve` and `validation_curve` in 1.8
2491+
@pytest.mark.parametrize(
2492+
"func, extra_args",
2493+
[
2494+
(cross_validate, {}),
2495+
(cross_val_score, {}),
2496+
(cross_val_predict, {}),
2497+
(learning_curve, {}),
2498+
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
2499+
],
2500+
)
2501+
def test_fit_param_deprecation(func, extra_args):
24932502
"""Check that we warn about deprecating `fit_params`."""
24942503
with pytest.warns(FutureWarning, match="`fit_params` is deprecated"):
2495-
func(estimator=ConsumingClassifier(), X=X, y=y, cv=2, fit_params={})
2504+
func(
2505+
estimator=ConsumingClassifier(), X=X, y=y, cv=2, fit_params={}, **extra_args
2506+
)
24962507

24972508
with pytest.raises(
24982509
ValueError, match="`params` and `fit_params` cannot both be provided"
24992510
):
2500-
func(estimator=ConsumingClassifier(), X=X, y=y, fit_params={}, params={})
2511+
func(
2512+
estimator=ConsumingClassifier(),
2513+
X=X,
2514+
y=y,
2515+
fit_params={},
2516+
params={},
2517+
**extra_args,
2518+
)
25012519

25022520

25032521
@pytest.mark.usefixtures("enable_slep006")
25042522
@pytest.mark.parametrize(
2505-
"func", [cross_validate, cross_val_score, cross_val_predict, learning_curve]
2523+
"func, extra_args",
2524+
[
2525+
(cross_validate, {}),
2526+
(cross_val_score, {}),
2527+
(cross_val_predict, {}),
2528+
(learning_curve, {}),
2529+
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
2530+
],
25062531
)
2507-
def test_groups_with_routing_validation(func):
2532+
def test_groups_with_routing_validation(func, extra_args):
25082533
"""Check that we raise an error if `groups` are passed to the cv method instead
25092534
of `params` when metadata routing is enabled.
25102535
"""
@@ -2514,14 +2539,22 @@ def test_groups_with_routing_validation(func):
25142539
X=X,
25152540
y=y,
25162541
groups=[],
2542+
**extra_args,
25172543
)
25182544

25192545

25202546
@pytest.mark.usefixtures("enable_slep006")
25212547
@pytest.mark.parametrize(
2522-
"func", [cross_validate, cross_val_score, cross_val_predict, learning_curve]
2548+
"func, extra_args",
2549+
[
2550+
(cross_validate, {}),
2551+
(cross_val_score, {}),
2552+
(cross_val_predict, {}),
2553+
(learning_curve, {}),
2554+
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
2555+
],
25232556
)
2524-
def test_passed_unrequested_metadata(func):
2557+
def test_passed_unrequested_metadata(func, extra_args):
25252558
"""Check that we raise an error when passing metadata that is not
25262559
requested."""
25272560
err_msg = re.escape("but are not explicitly set as requested or not requested")
@@ -2531,14 +2564,22 @@ def test_passed_unrequested_metadata(func):
25312564
X=X,
25322565
y=y,
25332566
params=dict(metadata=[]),
2567+
**extra_args,
25342568
)
25352569

25362570

25372571
@pytest.mark.usefixtures("enable_slep006")
25382572
@pytest.mark.parametrize(
2539-
"func", [cross_validate, cross_val_score, cross_val_predict, learning_curve]
2573+
"func, extra_args",
2574+
[
2575+
(cross_validate, {}),
2576+
(cross_val_score, {}),
2577+
(cross_val_predict, {}),
2578+
(learning_curve, {}),
2579+
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
2580+
],
25402581
)
2541-
def test_validation_functions_routing(func):
2582+
def test_validation_functions_routing(func, extra_args):
25422583
"""Check that the respective cv method is properly dispatching the metadata
25432584
to the consumer."""
25442585
scorer_registry = _Registry()
@@ -2563,12 +2604,11 @@ def test_validation_functions_routing(func):
25632604
fit_sample_weight = rng.rand(n_samples)
25642605
fit_metadata = rng.rand(n_samples)
25652606

2566-
extra_params = {
2607+
scoring_args = {
25672608
cross_validate: dict(scoring=dict(my_scorer=scorer, accuracy="accuracy")),
2568-
# cross_val_score and learning_curve don't support multiple scorers:
25692609
cross_val_score: dict(scoring=scorer),
25702610
learning_curve: dict(scoring=scorer),
2571-
# cross_val_predict doesn't need a scorer
2611+
validation_curve: dict(scoring=scorer),
25722612
cross_val_predict: dict(),
25732613
}
25742614

@@ -2590,7 +2630,8 @@ def test_validation_functions_routing(func):
25902630
X=X,
25912631
y=y,
25922632
cv=splitter,
2593-
**extra_params[func],
2633+
**scoring_args[func],
2634+
**extra_args,
25942635
params=params,
25952636
)
25962637

0 commit comments

Comments
 (0)
0