8000 MNT (SLEP6) remove other_params from provess_routing (#26909) · punndcoder28/scikit-learn@36d019f · GitHub
[go: up one dir, main page]

Skip to content

Commit 36d019f

Browse files
adrinjalalithomasjpfan
authored andcommitted
MNT (SLEP6) remove other_params from provess_routing (scikit-learn#26909)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 419458d commit 36d019f

File tree

10 files changed

+60
-68
lines changed

10 files changed

+60
-68
lines changed

doc/whats_new/v1.4.rst

+5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ Changelog
6666
- |Enhancement| :func:`base.clone` now supports `dict` as input and creates a
6767
copy. :pr:`26786` by `Adrin Jalali`_.
6868

69+
- |API|:func:`~utils.metadata_routing.process_routing` now has a different
70+
signature. The first two (the object and the method) are positional only,
71+
and all metadata are passed as keyword arguments. :pr:`26909` by `Adrin
72+
Jalali`_.
73+
6974
:mod:`sklearn.cross_decomposition`
7075
..................................
7176

examples/miscellaneous/plot_metadata_routing.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def get_metadata_routing(self):
447447
return router
448448

449449
def fit(self, X, y, **fit_params):
450-
params = process_routing(self, "fit", fit_params)
450+
params = process_routing(self, "fit", **fit_params)
451451

452452
self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
453453
X_transformed = self.transformer_. B41A transform(X, **params.transformer.transform)
@@ -458,7 +458,7 @@ def fit(self, X, y, **fit_params):
458458
return self
459459

460460
def predict(self, X, **predict_params):
461-
params = process_routing(self, "predict", predict_params)
461+
params = process_routing(self, "predict", **predict_params)
462462

463463
X_transformed = self.transformer_.transform(X, **params.transformer.transform)
464464
return self.classifier_.predict(X_transformed, **params.classifier.predict)
@@ -543,7 +543,7 @@ def __init__(self, estimator):
543543
self.estimator = estimator
544544

545545
def fit(self, X, y, **fit_params):
546-
params = process_routing(self, "fit", fit_params)
546+
params = process_routing(self, "fit", **fit_params)
547547
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
548548

549549
def get_metadata_routing(self):
@@ -572,7 +572,7 @@ def __init__(self, estimator):
572572
self.estimator = estimator
573573

574574
def fit(self, X, y, sample_weight=None, **fit_params):
575-
params = process_routing(self, "fit", fit_params, sample_weight=sample_weight)
575+
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
576576
check_metadata(self, sample_weight=sample_weight)
577577
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
578578

sklearn/calibration.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,10 @@ def fit(self, X, y, sample_weight=None, **fit_params):
378378

379379
if _routing_enabled():
380380
routed_params = process_routing(
381-
obj=self,
382-
method="fit",
381+
self,
382+
"fit",
383383
sample_weight=sample_weight,
384-
other_params=fit_params,
384+
**fit_params,
385385
)
386386
else:
387387
# sample_weight checks

sklearn/linear_model/_logistic.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1859,10 +1859,10 @@ def fit(self, X, y, sample_weight=None, **params):
18591859

18601860
if _routing_enabled():
18611861
routed_params = process_routing(
1862-
obj=self,
1863-
method="fit",
1862+
self,
1863+
"fit",
18641864
sample_weight=sample_weight,
1865-
other_params=params,
1865+
**params,
18661866
)
18671867
else:
18681868
routed_params = Bunch()
@@ -2150,10 +2150,10 @@ def score(self, X, y, sample_weight=None, **score_params):
21502150
scoring = self._get_scorer()
21512151
if _routing_enabled():
21522152
routed_params = process_routing(
2153-
obj=self,
2154-
method="score",
2153+
self,
2154+
"score",
21552155
sample_weight=sample_weight,
2156-
other_params=score_params,
2156+
**score_params,
21572157
)
21582158
else:
21592159
routed_params = Bunch()

sklearn/metrics/_scorer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __call__(self, estimator, *args, **kwargs):
124124
cached_call = partial(_cached_call, cache)
125125

126126
if _routing_enabled():
127-
routed_params = process_routing(self, "score", kwargs)
127+
routed_params = process_routing(self, "score", **kwargs)
128128
else:
129129
# they all get the same args, and they all get them all
130130
routed_params = Bunch(

sklearn/multioutput.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, **partial_fit_para
163163

164164
if _routing_enabled():
165165
routed_params = process_routing(
166-
obj=self,
167-
method="partial_fit",
168-
other_params=partial_fit_params,
166+
self,
167+
"partial_fit",
169168
sample_weight=sample_weight,
169+
**partial_fit_params,
170170
)
171171
else:
172172
if sample_weight is not None and not has_fit_parameter(
@@ -249,10 +249,10 @@ def fit(self, X, y, sample_weight=None, **fit_params):
249249

250250
if _routing_enabled():
251251
routed_params = process_routing(
252-
obj=self,
253-
method="fit",
254-
other_params=fit_params,
252+
self,
253+
"fit",
255254
sample_weight=sample_weight,
255+
**fit_params,
256256
)
257257
else:
258258
if sample_weight is not None and not has_fit_parameter(
@@ -706,9 +706,7 @@ def fit(self, X, Y, **fit_params):
706706
del Y_pred_chain
707707

708708
if _routing_enabled():
709-
routed_params = process_routing(
710-
obj=self, method="fit", other_params=fit_params
711-
)
709+
routed_params = process_routing(self, "fit", **fit_params)
712710
else:
713711
routed_params = Bunch(estimator=Bunch(fit=fit_params))
714712

sklearn/pipeline.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,7 @@ def _log_message(self, step_idx):
334334

335335
def _check_method_params(self, method, props, **kwargs):
336336
if _routing_enabled():
337-
routed_params = process_routing(
338-
self, method=method, other_params=props, **kwargs
339-
)
337+
routed_params = process_routing(self, method, **props, **kwargs)
340338
return routed_params
341339
else:
342340
fit_params_steps = Bunch(
@@ -586,7 +584,7 @@ def predict(self, X, **params):
586584
return self.steps[-1][1].predict(Xt, **params)
587585

588586
# metadata routing enabled
589-
routed_params = process_routing(self, "predict", other_params=params)
587+
routed_params = process_routing(self, "predict", **params)
590588
for _, name, transform in self._iter(with_final=False):
591589
Xt = transform.transform(Xt, **routed_params[name].transform)
592590
return self.steps[-1][1].predict(Xt, **routed_params[self.steps[-1][0]].predict)
@@ -706,7 +704,7 @@ def predict_proba(self, X, **params):
706704
return self.steps[-1][1].predict_proba(Xt, **params)
707705

708706
# metadata routing enabled
709-
routed_params = process_routing(self, "predict_proba", other_params=params)
707+
routed_params = process_routing(self, "predict_proba", **params)
710708
for _, name, transform in self._iter(with_final=False):
711709
Xt = transform.transform(Xt, **routed_params[name].transform)
712710
return self.steps[-1][1].predict_proba(
@@ -747,7 +745,7 @@ def decision_function(self, X, **params):
747745

748746
# not branching here since params is only available if
749747
# enable_metadata_routing=True
750-
routed_params = process_routing(self, "decision_function", other_params=params)
748+
routed_params = process_routing(self, "decision_function", **params)
751749

752750
Xt = X
753751
for _, name, transform in self._iter(with_final=False):
@@ -833,7 +831,7 @@ def predict_log_proba(self, X, **params):
833831
return self.steps[-1][1].predict_log_proba(Xt, **params)
834832

835833
# metadata routing enabled
836-
routed_params = process_routing(self, "predict_log_proba", other_params=params)
834+
routed_params = process_routing(self, "predict_log_proba", **params)
837835
for _, name, transform in self._iter(with_final=False):
838836
Xt = transform.transform(Xt, **routed_params[name].transform)
839837
return self.steps[-1][1].predict_log_proba(
@@ -882,7 +880,7 @@ def transform(self, X, **params):
882880

883881
# not branching here since params is only available if
884882
# enable_metadata_routing=True
885-
routed_params = process_routing(self, "transform", other_params=params)
883+
routed_params = process_routing(self, "transform", **params)
886884
Xt = X
887885
for _, name, transform in self._iter():
888886
Xt = transform.transform(Xt, **routed_params[name].transform)
@@ -925,7 +923,7 @@ def inverse_transform(self, Xt, **params):
925923

926924
# we don't have to branch here, since params is only non-empty if
927925
# enable_metadata_routing=True.
928-
routed_params = process_routing(self, "inverse_transform", other_params=params)
926+
routed_params = process_routing(self, "inverse_transform", **params)
929927
reverse_iter = reversed(list(self._iter()))
930928
for _, name, transform in reverse_iter:
931929
Xt = transform.inverse_transform(
@@ -981,7 +979,7 @@ def score(self, X, y=None, sample_weight=None, **params):
981979

982980
# metadata routing is enabled.
983981
routed_params = process_routing(
984-
self, "score", sample_weight=sample_weight, other_params=params
982+
self, "score", sample_weight=sample_weight, **params
985983
)
986984

987985
Xt = X

sklearn/tests/metadata_routing_common.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def __init__(self, estimator):
323323
self.estimator = estimator
324324

325325
def fit(self, X, y, **fit_params):
326-
params = process_routing(self, "fit", fit_params)
326+
params = process_routing(self, "fit", **fit_params)
327327
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
328328

329329
def get_metadata_routing(self):
@@ -345,12 +345,12 @@ def fit(self, X, y, sample_weight=None, **fit_params):
345345
self.registry.append(self)
346346

347347
record_metadata(self, "fit", sample_weight=sample_weight)
348-
params = process_routing(self, "fit", fit_params, sample_weight=sample_weight)
348+
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
349349
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
350350
return self
351351

352352
def predict(self, X, **predict_params):
353-
params = process_routing(self, "predict", predict_params)
353+
params = process_routing(self, "predict", **predict_params)
354354
return self.estimator_.predict(X, **params.estimator.predict)
355355

356356
def get_metadata_routing(self):
@@ -374,7 +374,7 @@ def fit(self, X, y, sample_weight=None, **kwargs):
374374
self.registry.append(self)
375375

376376
record_metadata(self, "fit", sample_weight=sample_weight)
377-
params = process_routing(self, "fit", kwargs, sample_weight=sample_weight)
377+
params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
378378
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
379379
return self
380380

@@ -394,12 +394,12 @@ def __init__(self, transformer):
394394
self.transformer = transformer
395395

396396
def fit(self, X, y=None, **fit_params):
397-
params = process_routing(self, "fit", fit_params)
397+
params = process_routing(self, "fit", **fit_params)
398398
self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
399399
return self
400400

401401
def transform(self, X, y=None, **transform_params):
402-
params = process_routing(self, "transform", transform_params)
402+
params = process_routing(self, "transform", **transform_params)
403403
return self.transformer_.transform(X, **params.transformer.transform)
404404

405405
def get_metadata_routing(self):

sklearn/tests/test_metadata_routing.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, steps):
7474

7575
def fit(self, X, y, **fit_params):
7676
self.steps_ = []
77-
params = process_routing(self, "fit", fit_params)
77+
params = process_routing(self, "fit", **fit_params)
7878
X_transformed = X
7979
for i, step in enumerate(self.steps[:-1]):
8080
transformer = clone(step).fit(
@@ -93,7 +93,7 @@ def fit(self, X, y, **fit_params):
9393
def predict(self, X, **predict_params):
9494
check_is_fitted(self)
9595
X_transformed = X
96-
params = process_routing(self, "predict", predict_params)
96+
params = process_routing(self, "predict", **predict_params)
9797
for i, step in enumerate(self.steps_[:-1]):
9898
X_transformed = step.transform(X, **params.get(f"step_{i}").transform)
9999

@@ -230,15 +230,15 @@ class OddEstimator(BaseEstimator):
230230

231231
def test_process_routing_invalid_method():
232232
with pytest.raises(TypeError, match="Can only route and process input"):
233-
process_routing(ConsumingClassifier(), "invalid_method", {})
233+
process_routing(ConsumingClassifier(), "invalid_method", **{})
234234

235235

236236
def test_process_routing_invalid_object():
237237
class InvalidObject:
238238
pass
239239

240240
with pytest.raises(AttributeError, match="has not implemented the routing"):
241-
process_routing(InvalidObject(), "fit", {})
241+
process_routing(InvalidObject(), "fit", **{})
242242

243243

244244
def test_simple_metadata_routing():

sklearn/utils/_metadata_requests.py

+16-25
Original file line numberDiff line numberDiff line change
@@ -1412,34 +1412,33 @@ def get_metadata_routing(self):
14121412
# given metadata. This is to minimize the boilerplate required in routers.
14131413

14141414

1415-
def process_routing(obj, method, other_params, **kwargs):
1415+
# Here the first two arguments are positional only which makes everything
1416+
# passed as keyword argument a metadata. The first two args also have an `_`
1417+
# prefix to reduce the chances of name collisions with the passed metadata, and
1418+
# since they're positional only, users will never type those underscores.
1419+
def process_routing(_obj, _method, /, **kwargs):
14161420
"""Validate and route input parameters.
14171421
14181422
This function is used inside a router's method, e.g. :term:`fit`,
14191423
to validate the metadata and handle the routing.
14201424
14211425
Assuming this signature: ``fit(self, X, y, sample_weight=None, **fit_params)``,
14221426
a call to this function would be:
1423-
``process_routing(self, fit_params, sample_weight=sample_weight)``.
1427+
``process_routing(self, sample_weight=sample_weight, **fit_params)``.
14241428
14251429
.. versionadded:: 1.3
14261430
14271431
Parameters
14281432
----------
1429-
obj : object
1433+
_obj : object
14301434
An object implementing ``get_metadata_routing``. Typically a
14311435
meta-estimator.
14321436
1433-
method : str
1437+
_method : str
14341438
The name of the router's method in which this function is called.
14351439
1436-
other_params : dict
1437-
A dictionary of extra parameters passed to the router's method,
1438-
e.g. ``**fit_params`` passed to a meta-estimator's :term:`fit`.
1439-
14401440
**kwargs : dict
1441-
Parameters explicitly accepted and included in the router's method
1442-
signature.
1441+
Metadata to be routed.
14431442
14441443
Returns
14451444
-------
@@ -1449,27 +1448,19 @@ def process_routing(obj, method, other_params, **kwargs):
14491448
corresponding methods or corresponding child objects. The object names
14501449
are those defined in `obj.get_metadata_routing()`.
14511450
"""
1452-
if not hasattr(obj, "get_metadata_routing"):
1451+
if not hasattr(_obj, "get_metadata_routing"):
14531452
raise AttributeError(
1454-
f"This {repr(obj.__class__.__name__)} has not implemented the routing"
1453+
f"This {repr(_obj.__class__.__name__)} has not implemented the routing"
14551454
" method `get_metadata_routing`."
14561455
)
1457-
if method not in METHODS:
1456+
if _method not in METHODS:
14581457
raise TypeError(
14591458
f"Can only route and process input on these methods: {METHODS}, "
1460-
f"while the passed method is: {method}."
1459+
f"while the passed method is: {_method}."
14611460
)
14621461

1463-
# We take the extra params (**fit_params) which is passed as `other_params`
1464-
# and add the explicitly passed parameters (passed as **kwargs) to it. This
1465-
# is equivalent to a code such as this in a router:
1466-
# if sample_weight is not None:
1467-
# fit_params["sample_weight"] = sample_weight
1468-
all_params = other_params if other_params is not None else dict()
1469-
all_params.update(kwargs)
1470-
1471-
request_routing = get_routing_for_object(obj)
1472-
request_routing.validate_metadata(params=all_params, method=method)
1473-
routed_params = request_routing.route_params(params=all_params, caller=method)
1462+
request_routing = get_routing_for_object(_obj)
1463+
request_routing.validate_metadata(params=kwargs, method=_method)
1464+
routed_params = request_routing.route_params(params=kwargs, caller=_method)
14741465

14751466
return routed_params

0 commit comments

Comments
 (0)
0