8000 MNT metadata routing: remove `MethodMapping.from_str()` and sort `cal… · scikit-learn/scikit-learn@2bafd7b · GitHub
[go: up one dir, main page]

Skip to content

Commit 2bafd7b

Browse files
StefanieSengeradrinjalaliglemaitre
authored
MNT metadata routing: remove MethodMapping.from_str() and sort caller, callee in MethodPair() (#28422)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent f61dd6c commit 2bafd7b

File tree

14 files changed

+112
-128
lines changed
Collapse file tree

14 files changed

+112
-128
lines changed

examples/miscellaneous/plot_metadata_routing.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ def get_metadata_routing(self):
167167
router = MetadataRouter(owner=self.__class__.__name__).add(
168168
estimator=self.estimator,
169169
method_mapping=MethodMapping()
170-
.add(callee="fit", caller="fit")
171-
.add(callee="predict", caller="predict")
172-
.add(callee="score", caller="score"),
170+
.add(caller="fit", callee="fit")
171+
.add(caller="predict", callee="predict")
172+
.add(caller="score", callee="score"),
173173
)
174174
return router
175175

@@ -356,9 +356,9 @@ def get_metadata_routing(self):
356356
.add(
357357
estimator=self.estimator,
358358
method_mapping=MethodMapping()
359-
.add(callee="fit", caller="fit")
360-
.add(callee="predict", caller="predict")
361-
.add(callee="score", caller="score"),
359+
.add(caller="fit", callee="fit")
360+
.add(caller="predict", callee="predict")
361+
.add(caller="score", callee="score"),
362362
)
363363
)
364364
return router
@@ -488,16 +488,16 @@ def get_metadata_routing(self):
488488
# The metadata is routed such that it retraces how
489489
# `SimplePipeline` internally calls the transformer's `fit` and
490490
# `transform` methods in its own methods (`fit` and `predict`).
491-
.add(callee="fit", caller="fit")
492-
.add(callee="transform", caller="fit")
493-
.add(callee="transform", caller="predict"),
491+
.add(caller="fit", callee="fit")
492+
.add(caller="fit", callee="transform")
493+
.add(caller="predict", callee="transform"),
494494
)
495495
# We add the routing for the classifier.
496496
.add(
497497
classifier=self.classifier,
498498
method_mapping=MethodMapping()
499-
.add(callee="fit", caller="fit")
500-
.add(callee="predict", caller="predict"),
499+
.add(caller="fit", callee="fit")
500+
.add(caller="predict", callee="predict"),
501501
)
502502
)
503503
return router
@@ -612,7 +612,7 @@ def fit(self, X, y, **fit_params):
612612
def get_metadata_routing(self):
613613
router = MetadataRouter(owner=self.__class__.__name__).add(
614614
estimator=self.estimator,
615-
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
615+
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
616616
)
617617
return router
618618

@@ -651,7 +651,7 @@ def get_metadata_routing(self):
651651
.add_self_request(self)
652652
.add(
653653
estimator=self.estimator,
654-
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
654+
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
655655
)
656656
)
657657
return router
@@ -692,7 +692,7 @@ def predict(self, X):
692692
print(w.message)
693693

694694
# %%
695-
# In the end, we disable the configuration flag for metadata routing:
695+
# At the end we disable the configuration flag for metadata routing:
696696

697697
set_config(enable_metadata_routing=False)
698698

sklearn/calibration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,11 @@ def get_metadata_routing(self):
523523
.add_self_request(self)
524524
.add(
525525
estimator=self._get_estimator(),
526-
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
526+
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
527527
)
528528
.add(
529529
splitter=self.cv,
530-
method_mapping=MethodMapping().add(callee="split", caller="fit"),
530+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
531531
)
532532
)
533533
return router

sklearn/feature_selection/_from_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ def get_metadata_routing(self):
513513
router = MetadataRouter(owner=self.__class__.__name__).add(
514514
estimator=self.estimator,
515515
method_mapping=MethodMapping()
516-
.add(callee="partial_fit", caller="partial_fit")
517-
.add(callee="fit", caller="fit"),
516+
.add(caller="partial_fit", callee="partial_fit")
517+
.add(caller="fit", callee="fit"),
518518
)
519519
return router
520520

sklearn/linear_model/_coordinate_descent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1860,7 +1860,7 @@ def get_metadata_routing(self):
18601860
.add_self_request(self)
18611861
.add(
18621862
splitter=check_cv(self.cv),
1863-
method_mapping=MethodMapping().add(callee="split", caller="fit"),
1863+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
18641864
)
18651865
)
18661866
return router

sklearn/linear_model/_least_angle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1821,7 +1821,7 @@ def get_metadata_routing(self):
18211821
"""
18221822
router = MetadataRouter(owner=self.__class__.__name__).add(
18231823
splitter=check_cv(self.cv),
1824-
method_mapping=MethodMapping().add(callee="split", caller="fit"),
1824+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
18251825
)
18261826
return router
18271827

sklearn/linear_model/_logistic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,13 +2166,13 @@ def get_metadata_routing(self):
21662166
.add_self_request(self)
21672167
.add(
21682168
splitter=self.cv,
2169-
method_mapping=MethodMapping().add(callee="split", caller="fit"),
2169+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
21702170
)
21712171
.add(
21722172
scorer=self._get_scorer(),
21732173
method_mapping=MethodMapping()
2174-
.add(callee="score", caller="score")
2175-
.add(callee="score", caller="fit"),
2174+
.add(caller="score", callee="score")
2175+
.add(caller="fit", callee="score"),
21762176
)
21772177
)
21782178
return router

sklearn/linear_model/_omp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,6 @@ def get_metadata_routing(self):
11161116

11171117
router = MetadataRouter(owner=self.__class__.__name__).add(
11181118
splitter=self.cv,
1119-
method_mapping=MethodMapp F438 ing().add(callee="split", caller="fit"),
1119+
method_mapping=MethodMapping().add(caller="fit", callee="split"),
11201120
)
11211121
return router

sklearn/metrics/_scorer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..utils.metadata_routing import (
3333
MetadataRequest,
3434
MetadataRouter,
35+
MethodMapping,
3536
_MetadataRequester,
3637
_raise_for_params,
3738
_routing_enabled,
@@ -188,7 +189,8 @@ def get_metadata_routing(self):
188189
routing information.
189190
"""
190191
return MetadataRouter(owner=self.__class__.__name__).add(
191-
**self._scorers, method_mapping="score"
192+
**self._scorers,
193+
method_mapping=MethodMapping().add(caller="score", callee="score"),
192194
)
193195

194196

sklearn/metrics/tests/test_score_objects.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
assert_array_equal,
6262
ignore_warnings,
6363
)
64-
from sklearn.utils.metadata_routing import MetadataRouter
64+
from sklearn.utils.metadata_routing import MetadataRouter, MethodMapping
6565

6666
REGRESSION_SCORERS = [
6767
"d2_absolute_error_score",
@@ -1233,7 +1233,8 @@ def test_scorer_metadata_request(name):
12331233
# make sure putting the scorer in a router doesn't request anything by
12341234
# default
12351235
router = MetadataRouter(owner="test").add(
1236-
method_mapping="score", scorer=get_scorer(name)
1236+
scorer=get_scorer(name),
1237+
method_mapping=MethodMapping().add(caller="score", callee="score"),
12371238
)
12381239
# make sure `sample_weight` is refused if passed.
12391240
with pytest.raises(TypeError, match="got unexpected argument"):
@@ -1244,7 +1245,8 @@ def test_scorer_metadata_request(name):
12441245

12451246
# make sure putting weighted_scorer in a router requests sample_weight
12461247
router = MetadataRouter(owner="test").add(
1247-
scorer=weighted_scorer, method_mapping="score"
1248+
scorer=weighted_scorer,
1249+
method_mapping=MethodMapping().add(caller="score", callee="score"),
12481250
)
12491251
router.validate_metadata(params={"sample_weight": 1}, method="score")
12501252
routed_params = router.route_params(params={"sample_weight": 1}, caller="score")

sklearn/multiclass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,8 @@ def get_metadata_routing(self):
619619
.add(
620620
estimator=self.estimator,
621621
method_mapping=MethodMapping()
622-
.add(callee="fit", caller="fit")
623-
.add(callee="partial_fit", caller="partial_fit"),
622+
.add(caller="fit", callee="fit")
623+
.add(caller="partial_fit", callee="partial_fit"),
624624
)
625625
)
626626
return router
@@ -1018,8 +1018,8 @@ def get_metadata_routing(self):
10181018
.add(
10191019
estimator=self.estimator,
10201020
method_mapping=MethodMapping()
1021-
.add(callee="fit", caller="fit")
1022-
.add(callee="partial_fit", caller="partial_fit"),
1021+
.add(caller="fit", callee="fit")
1022+
.add(caller="partial_fit", callee="partial_fit"),
10231023
)
10241024
)
10251025
return router
@@ -1264,6 +1264,6 @@ def get_metadata_routing(self):
12641264

12651265
router = MetadataRouter(owner=self.__class__.__name__).add(
12661266
estimator=self.estimator,
1267-
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
1267+
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
12681268
)
12691269
return router

0 commit comments

Comments
 (0)
0