@@ -298,7 +298,7 @@ def enable_slep006():
298
298
"y" : y ,
299
299
"preserves_metadata" : "subset" ,
300
300
"estimator_routing_methods" : ["fit" , "predict" , "score" ],
301
- "requests_set_together " : {"fit" : ["score" ]},
301
+ "method_mapping " : {"fit" : ["fit" , "score" ]},
302
302
},
303
303
{
304
304
"metaestimator" : IterativeImputer ,
@@ -339,12 +339,9 @@ def enable_slep006():
339
339
to the splitter
340
340
- method_args: a dict of dicts, defining extra arguments needed to be passed to
341
341
methods, such as passing `classes` to `partial_fit`.
342
- - requests_set_together: a dict that defines which set_{method}_requests need
343
- to be set together with the key; used in case a router routes to different
344
- methods from the sub-estimator from within the same meta-estimator's method.
345
- For instance, {"fit": ["score"]} would signal that
346
- `estimator.set_fit_request` premises `estimator.set_score_request` to be set
347
- as well.
342
+ - method_mapping: a dict of the form `{caller: [callee1, ...]}` which signals
343
+ which `.set_{method}_request` methods should be called to set request values.
344
+ If not present, a one-to-one mapping is assumed.
348
345
"""
349
346
350
347
# IDs used by pytest to get meaningful verbose messages when running the tests
@@ -442,13 +439,36 @@ def get_init_args(metaestimator_info, sub_estimator_consumes):
442
439
)
443
440
444
441
445
- def set_requests (estimator , methods , metadata_name ):
446
8000
td>- """Call `set_fit_request` on a list of methods from the sub-estimator."""
447
- for method in methods :
448
- set_request_for_method = getattr (estimator , f"set_{ method } _request" )
449
- set_request_for_method (** {metadata_name : True })
450
- if is_classifier (estimator ) and method == "partial_fit" :
451
- set_request_for_method (classes = True )
442
+ def set_requests (estimator , * , method_mapping , methods , metadata_name , value = True ):
443
+ """Call `set_{method}_request` on a list of methods from the sub-estimator.
444
+
445
+ Parameters
446
+ ----------
447
+ estimator : BaseEstimator
448
+ The estimator for which `set_{method}_request` methods are called.
449
+
450
+ method_mapping : dict
451
+ The method mapping in the form of `{caller: [callee, ...]}`.
452
+ If a "caller" is not present in the method mapping, a one-to-one mapping is
453
+ assumed.
454
+
455
+ methods : list of str
456
+ The list of methods as "caller"s for which the request for the child should
457
+ be set.
458
+
459
+ metadata_name : str
460
+ The name of the metadata to be routed, usually either `"metadata"` or
461
+ `"sample_weight"` in our tests.
462
+
463
+ value : None, bool, or str
464
+ The request value to be set, by default it's `True`
465
+ """
466
+ for caller in methods :
467
+ for callee in method_mapping .get (caller , [caller ]):
468
+ set_request_for_method = getattr (estimator , f"set_{ callee } _request" )
469
+ set_request_for_method (** {metadata_name : value })
470
+ if is_classifier (estimator ) and callee == "partial_fit" :
471
+ set_request_for_method (classes = True )
452
472
453
473
454
474
@pytest .mark .parametrize ("estimator" , UNSUPPORTED_ESTIMATORS )
@@ -531,13 +551,26 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator):
531
551
method = getattr (instance , method_name )
532
552
if "fit" not in method_name :
533
553
# set request on fit
534
- set_requests (estimator , methods = ["fit" ], metadata_name = key )
535
- # make sure error message corresponding to `method_name`
536
- # is used for test
537
- if method_name != "score" :
538
- set_requests (estimator , methods = ["score" ], metadata_name = key )
554
+ set_requests (
555
+ estimator ,
556
+ method_mapping = metaestimator .get ("method_mapping" , {}),
557
+ methods = ["fit" ],
558
+ metadata_name = key ,
559
+ )
539
560
instance .fit (X , y , ** method_kwargs )
540
561
try :
562
+ # making sure the requests are unset, in case they were set as a
563
+ # side effect of setting them for fit. For instance, if method
564
+ # mapping for fit is: `"fit": ["fit", "score"]`, that would mean
565
+ # calling `.score` here would not raise, because we have already
566
+ # set request value for child estimator's `score`.
567
+ set_requests (
568
+ estimator ,
569
+ method_mapping = metaestimator .get ("method_mapping" , {}),
570
+ methods = ["fit" ],
571
+ metadata_name = key ,
572
+ value = None ,
573
+ )
541
574
# `fit` and `partial_fit` accept y, others don't.
542
575
method (X , y , ** method_kwargs )
543
576
except TypeError :
@@ -557,7 +590,7 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator):
557
590
X = metaestimator ["X" ]
558
591
y = metaestimator ["y" ]
559
592
routing_methods = metaestimator ["estimator_routing_methods" ]
560
- requests_set_together = metaestimator .get ("requests_set_together " , {})
593
+ method_mapping = metaestimator .get ("method_mapping " , {})
561
594
preserves_metadata = metaestimator .get ("preserves_metadata" , True )
562
595
563
596
for method_name in routing_methods :
@@ -569,16 +602,19 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator):
569
602
metaestimator , sub_estimator_consumes = True
570
603
)
571
604
if scorer :
572
- set_requests (scorer , methods = ["score" ], metadata_name = key )
605
+ set_requests (
606
+ scorer , method_mapping = {}, methods = ["score" ], metadata_name = key
607
+ )
573
608
if cv :
574
609
cv .set_split_request (groups = True , metadata = True )
575
610
576
611
# `set_{method}_request({metadata}==True)` on the underlying objects
577
- set_requests (estimator , methods = [method_name ], metadata_name = key )
578
- if requests_set_together :
579
- set_requests (
580
- estimator , methods = requests_set_together ["fit" ], metadata_name = key
581
- )
612
+ set_requests (
613
+ estimator ,
614
+ method_mapping = method_mapping ,
615
+ methods = [method_name ],
616
+ metadata_name = key ,
617
+ )
582
618
583
619
instance = cls (** kwargs )
584
620
method = getattr (instance , method_name )
@@ -587,13 +623,12 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator):
587
623
)
588
624
if "fit" not in method_name :
589
625
# fit before calling method
590
-
8000
set_requests (estimator , methods = ["fit" ], metadata_name = key )
591
- if requests_set_together :
592
- set_requests (
593
- estimator ,
594
- methods = requests_set_together ["fit" ],
595
- metadata_name = key ,
596
- )
626
+ set_requests (
627
+ estimator ,
628
+ method_mapping = metaestimator .get ("method_mapping" , {}),
629
+ methods = ["fit" ],
630
+ metadata_name = key ,
631
+ )
597
632
instance .fit (X , y , ** method_kwargs , ** extra_method_args )
598
633
try :
599
634
# `fit` and `partial_fit` accept y, others don't.
0 commit comments