8000 Add method_mapping to the common test definitions of the metaestimators · scikit-learn/scikit-learn@d3ac747 · GitHub
[go: up one dir, main page]

Skip to content

Commit d3ac747

Browse files
committed
Add method_mapping to the common test definitions of the metaestimators
1 parent 4f45a6b commit d3ac747

File tree

1 file changed

+68
-33
lines changed

1 file changed

+68
-33
lines changed

sklearn/tests/test_metaestimators_metadata_routing.py

Lines changed: 68 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def enable_slep006():
298298
"y": y,
299299
"preserves_metadata": "subset",
300300
"estimator_routing_methods": ["fit", "predict", "score"],
301-
"requests_set_together": {"fit": ["score"]},
301+
"method_mapping": {"fit": ["fit", "score"]},
302302
},
303303
{
304304
"metaestimator": IterativeImputer,
@@ -339,12 +339,9 @@ def enable_slep006():
339339
to the splitter
340340
- method_args: a dict of dicts, defining extra arguments needed to be passed to
341341
methods, such as passing `classes` to `partial_fit`.
342-
- requests_set_together: a dict that defines which set_{method}_requests need
343-
to be set together with the key; used in case a router routes to different
344-
methods from the sub-estimator from within the same meta-estimator's method.
345-
For instance, {"fit": ["score"]} would signal that
346-
`estimator.set_fit_request` premises `estimator.set_score_request` to be set
347-
as well.
342+
- method_mapping: a dict of the form `{caller: [callee1, ...]}` which signals
343+
which `.set_{method}_request` methods should be called to set request values.
344+
If not present, a one-to-one mapping is assumed.
348345
"""
349346

350347
# IDs used by pytest to get meaningful verbose messages when running the tests
@@ -442,13 +439,36 @@ def get_init_args(metaestimator_info, sub_estimator_consumes):
442439
)
443440

444441

445-
def set_requests(estimator, methods, metadata_name):
446-
"""Call `set_fit_request` on a list of methods from the sub-estimator."""
447-
for method in methods:
448-
set_request_for_method = getattr(estimator, f"set_{method}_request")
449-
set_request_for_method(**{metadata_name: True})
450-
if is_classifier(estimator) and method == "partial_fit":
451-
set_request_for_method(classes=True)
442+
def set_requests(estimator, *, method_mapping, methods, metadata_name, value=True):
443+
"""Call `set_{method}_request` on a list of methods from the sub-estimator.
444+
445+
Parameters
446+
----------
447+
estimator : BaseEstimator
448+
The estimator for which `set_{method}_request` methods are called.
449+
450+
method_mapping : dict
451+
The method mapping in the form of `{caller: [callee, ...]}`.
452+
If a "caller" is not present in the method mapping, a one-to-one mapping is
453+
assumed.
454+
455+
methods : list of str
456+
The list of methods as "caller"s for which the request for the child should
457+
be set.
458+
459+
metadata_name : str
460+
The name of the metadata to be routed, usually either `"metadata"` or
461+
`"sample_weight"` in our tests.
462+
463+
value : None, bool, or str
464+
The request value to be set, by default it's `True`
465+
"""
466+
for caller in methods:
467+
for callee in method_mapping.get(caller, [caller]):
468+
set_request_for_method = getattr(estimator, f"set_{callee}_request")
469+
set_request_for_method(**{metadata_name: value})
470+
if is_classifier(estimator) and callee == "partial_fit":
471+
set_request_for_method(classes=True)
452472

453473

454474
@pytest.mark.parametrize("estimator", UNSUPPORTED_ESTIMATORS)
@@ -531,13 +551,26 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator):
531551
method = getattr(instance, method_name)
532552
if "fit" not in method_name:
533553
# set request on fit
534-
set_requests(estimator, methods=["fit"], metadata_name=key)
535-
# make sure error message corresponding to `method_name`
536-
# is used for test
537-
if method_name != "score":
538-
set_requests(estimator, methods=["score"], metadata_name=key)
554+
set_requests(
555+
estimator,
556+
method_mapping=metaestimator.get("method_mapping", {}),
557+
methods=["fit"],
558+
metadata_name=key,
559+
)
539560
instance.fit(X, y, **method_kwargs)
540561
try:
562+
# making sure the requests are unset, in case they were set as a
563+
# side effect of setting them for fit. For instance, if method
564+
# mapping for fit is: `"fit": ["fit", "score"]`, that would mean
565+
# calling `.score` here would not raise, because we have already
566+
# set request value for child estimator's `score`.
567+
set_requests(
568+
estimator,
569+
method_mapping=metaestimator.get("method_mapping", {}),
570+
methods=["fit"],
571+
metadata_name=key,
572+
value=None,
573+
)
541574
# `fit` and `partial_fit` accept y, others don't.
542575
method(X, y, **method_kwargs)
543576
except TypeError:
@@ -557,7 +590,7 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator):
557590
X = metaestimator["X"]
558591
y = metaestimator["y"]
559592
routing_methods = metaestimator["estimator_routing_methods"]
560-
requests_set_together = metaestimator.get("requests_set_together", {})
593+
method_mapping = metaestimator.get("method_mapping", {})
561594
preserves_metadata = metaestimator.get("preserves_metadata", True)
562595

563596
for method_name in routing_methods:
@@ -569,16 +602,19 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator):
569602
metaestimator, sub_estimator_consumes=True
570603
)
571604
if scorer:
572-
set_requests(scorer, methods=["score"], metadata_name=key)
605+
set_requests(
606+
scorer, method_mapping={}, methods=["score"], metadata_name=key
607+
)
573608
if cv:
574609
cv.set_split_request(groups=True, metadata=True)
575610

576611
# `set_{method}_request({metadata}==True)` on the underlying objects
577-
set_requests(estimator, methods=[method_name], metadata_name=key)
578-
if requests_set_together:
579-
set_requests(
580-
estimator, methods=requests_set_together["fit"], metadata_name=key
581-
)
612+
set_requests(
613+
estimator,
614+
method_mapping=method_mapping,
615+
methods=[method_name],
616+
metadata_name=key,
617+
)
582618

583619
instance = cls(**kwargs)
584620
method = getattr(instance, method_name)
@@ -587,13 +623,12 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator):
587623
)
588624
if "fit" not in method_name:
589625
# fit before calling method
590-
8000 set_requests(estimator, methods=["fit"], metadata_name=key)
591-
if requests_set_together:
592-
set_requests(
593-
estimator,
594-
methods=requests_set_together["fit"],
595-
metadata_name=key,
596-
)
626+
set_requests(
627+
estimator,
628+
method_mapping=metaestimator.get("method_mapping", {}),
629+
methods=["fit"],
630+
metadata_name=key,
631+
)
597632
instance.fit(X, y, **method_kwargs, **extra_method_args)
598633
try:
599634
# `fit` and `partial_fit` accept y, others don't.

0 commit comments

Comments
 (0)
0