8000 MNT (SLEP6) remove other_params from provess_routing by adrinjalali · Pull Request #26909 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ Changelog
- |Enhancement| :func:`base.clone` now supports `dict` as input and creates a
copy. :pr:`26786` by `Adrin Jalali`_.

- |API|:func:`~utils.metadata_routing.process_routing` now has a different
signature. The first two (the object and the method) are positional only,
and all metadata are passed as keyword arguments. :pr:`26909` by `Adrin
Jalali`_.

:mod:`sklearn.cross_decomposition`
..................................

Expand Down
8 changes: 4 additions & 4 deletions examples/miscellaneous/plot_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def get_metadata_routing(self):
return router

def fit(self, X, y, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)

self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
X_transformed = self.transformer_.transform(X, **params.transformer.transform)
Expand All @@ -458,7 +458,7 @@ def fit(self, X, y, **fit_params):
return self

def predict(self, X, **predict_params):
params = process_routing(self, "predict", predict_params)
params = process_routing(self, "predict", **predict_params)

X_transformed = self.transformer_.transform(X, **params.transformer.transform)
return self.classifier_.predict(X_transformed, **params.classifier.predict)
Expand Down Expand Up @@ -543,7 +543,7 @@ def __init__(self, estimator):
self.estimator = estimator

def fit(self, X, y, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

def get_metadata_routing(self):
Expand Down Expand Up @@ -572,7 +572,7 @@ def __init__(self, estimator):
self.estimator = estimator

def fit(self, X, y, sample_weight=None, **fit_params):
params = process_routing(self, "fit", fit_params, sample_weight=sample_weight)
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
check_metadata(self, sample_weight=sample_weight)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

Expand Down
6 changes: 3 additions & 3 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,10 @@ def fit(self, X, y, sample_weight=None, **fit_params):

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="fit",
self,
"fit",
sample_weight=sample_weight,
other_params=fit_params,
**fit_params,
)
else:
# sample_weight checks
Expand Down
12 changes: 6 additions & 6 deletions sklearn/linear_model/_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,10 +1859,10 @@ def fit(self, X, y, sample_weight=None, **params):

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="fit",
self,
"fit",
sample_weight=sample_weight,
other_params=params,
**params,
)
else:
routed_params = Bunch()
Expand Down Expand Up @@ -2150,10 +2150,10 @@ def score(self, X, y, sample_weight=None, **score_params):
scoring = self._get_scorer()
if _routing_enabled():
routed_params = process_routing(
obj=self,
method="score",
self,
"score",
sample_weight=sample_weight,
other_params=score_params,
**score_params,
)
else:
routed_params = Bunch()
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __call__(self, estimator, *args, **kwargs):
cached_call = partial(_cached_call, cache)

if _routing_enabled():
routed_params = process_routing(self, "score", kwargs)
routed_params = process_routing(self, "score", **kwargs)
else:
# they all get the same args, and they all get them all
routed_params = Bunch(
Expand Down
16 changes: 7 additions & 9 deletions sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, **partial_fit_para

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="partial_fit",
other_params=partial_fit_params,
self,
"partial_fit",
sample_weight=sample_weight,
**partial_fit_params,
)
else:
if sample_weight is not None and not has_fit_parameter(
Expand Down Expand Up @@ -249,10 +249,10 @@ def fit(self, X, y, sample_weight=None, **fit_params):

if _routing_enabled():
routed_params = process_routing(
obj=self,
method="fit",
other_params=fit_params,
self,
"fit",
sample_weight=sample_weight,
**fit_params,
)
else:
if sample_weight is not None and not has_fit_parameter(
Expand Down Expand Up @@ -706,9 +706,7 @@ def fit(self, X, Y, **fit_params):
del Y_pred_chain

if _routing_enabled():
routed_params = process_routing(
obj=self, method="fit", other_params=fit_params
)
routed_params = process_routing(self, "fit", **fit_params)
else:
routed_params = Bunch(estimator=Bunch(fit=fit_params))

Expand Down
18 changes: 8 additions & 10 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,7 @@ def _log_message(self, step_idx):

def _check_method_params(self, method, props, **kwargs):
if _routing_enabled():
routed_params = process_routing(
self, method=method, other_params=props, **kwargs
)
routed_params = process_routing(self, method, **props, **kwargs)
return routed_params
else:
fit_params_steps = Bunch(
Expand Down Expand Up @@ -586,7 +584,7 @@ def predict(self, X, **params):
return self.steps[-1][1].predict(Xt, **params)

# metadata routing enabled
routed_params = process_routing(self, "predict", other_params=params)
routed_params = process_routing(self, "predict", **params)
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(Xt, **routed_params[name].transform)
return self.steps[-1][1].predict(Xt, **routed_params[self.steps[-1][0]].predict)
Expand Down Expand Up @@ -706,7 +704,7 @@ def predict_proba(self, X, **params):
return self.steps[-1][1].predict_proba(Xt, **params)

# metadata routing enabled
routed_params = process_routing(self, "predict_proba", other_params=params)
routed_params = process_routing(self, "predict_proba", **params)
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(Xt, **routed_params[name].transform)
return self.steps[-1][1].predict_proba(
Expand Down Expand Up @@ -747,7 +745,7 @@ def decision_function(self, X, **params):

# not branching here since params is only available if
# enable_metadata_routing=True
routed_params = process_routing(self, "decision_function", other_params=params)
routed_params = process_routing(self, "decision_function", **params)

Xt = X
for _, name, transform in self._iter(with_final=False):
Expand Down Expand Up @@ -833,7 +831,7 @@ def predict_log_proba(self, X, **params):
return self.steps[-1][1].predict_log_proba(Xt, **params)

# metadata routing enabled
routed_params = process_routing(self, "predict_log_proba", other_params=params)
routed_params = process_routing(self, "predict_log_proba", **params)
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(Xt, **routed_params[name].transform)
return self.steps[-1][1].predict_log_proba(
Expand Down Expand Up @@ -882,7 +880,7 @@ def transform(self, X, **params):

# not branching here since params is only available if
# enable_metadata_routing=True
routed_params = process_routing(self, "transform", other_params=params)
routed_params = process_routing(self, "transform", **params)
Xt = X
for _, name, transform in self._iter():
Xt = transform.transform(Xt, **routed_params[name].transform)
Expand Down Expand Up @@ -925,7 +923,7 @@ def inverse_transform(self, Xt, **params):

# we don't have to branch here, since params is only non-empty if
# enable_metadata_routing=True.
routed_params = process_routing(self, "inverse_transform", other_params=params)
routed_params = process_routing(self, "inverse_transform", **params)
reverse_iter = reversed(list(self._iter()))
for _, name, transform in reverse_iter:
Xt = transform.inverse_transform(
Expand Down Expand Up @@ -981,7 +979,7 @@ def score(self, X, y=None, sample_weight=None, **params):

# metadata routing is enabled.
routed_params = process_routing(
self, "score", sample_weight=sample_weight, other_params=params
self, "score", sample_weight=sample_weight, **params
)

Xt = X
Expand Down
12 changes: 6 additions & 6 deletions sklearn/tests/metadata_routing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def __init__(self, estimator):
self.estimator = estimator

def fit(self, X, y, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)

def get_metadata_routing(self):
Expand All @@ -345,12 +345,12 @@ def fit(self, X, y, sample_weight=None, **fit_params):
self.registry.append(self)

record_metadata(self, "fit", sample_weight=sample_weight)
params = process_routing(self, "fit", fit_params, sample_weight=sample_weight)
params = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
return self

def predict(self, X, **predict_params):
params = process_routing(self, "predict", predict_params)
params = process_routing(self, "predict", **predict_params)
return self.estimator_.predict(X, **params.estimator.predict)

def get_metadata_routing(self):
Expand All @@ -374,7 +374,7 @@ def fit(self, X, y, sample_weight=None, **kwargs):
self.registry.append(self)

record_metadata(self, "fit", sample_weight=sample_weight)
params = process_routing(self, "fit", kwargs, sample_weight=sample_weight)
params = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
return self

Expand All @@ -394,12 +394,12 @@ def __init__(self, transformer):
self.transformer = transformer

def fit(self, X, y=None, **fit_params):
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit)
return self

def transform(self, X, y=None, **transform_params):
params = process_routing(self, "transform", transform_params)
params = process_routing(self, "transform", **transform_params)
return self.transformer_.transform(X, **params.transformer.transform)

def get_metadata_routing(self):
Expand Down
8 changes: 4 additions & 4 deletions sklearn/tests/test_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, steps):

def fit(self, X, y, **fit_params):
self.steps_ = []
params = process_routing(self, "fit", fit_params)
params = process_routing(self, "fit", **fit_params)
X_transformed = X
for i, step in enumerate(self.steps[:-1]):
transformer = clone(step).fit(
Expand All @@ -93,7 +93,7 @@ def fit(self, X, y, **fit_params):
def predict(self, X, **predict_params):
check_is_fitted(self)
X_transformed = X
params = process_routing(self, "predict", predict_params)
params = process_routing(self, "predict", **predict_params)
for i, step in enumerate(self.steps_[:-1]):
X_transformed = step.transform(X, **params.get(f"step_{i}").transform)

Expand Down Expand Up @@ -230,15 +230,15 @@ class OddEstimator(BaseEstimator):

def test_process_routing_invalid_method():
with pytest.raises(TypeError, match="Can only route and process input"):
process_routing(ConsumingClassifier(), "invalid_method", {})
process_routing(ConsumingClassifier(), "invalid_method", **{})


def test_process_routing_invalid_object():
class InvalidObject:
pass

with pytest.raises(AttributeError, match="has not implemented the routing"):
process_routing(InvalidObject(), "fit", {})
process_routing(InvalidObject(), "fit", **{})


def test_simple_metadata_routing():
Expand Down
41 changes: 16 additions & 25 deletions sklearn/utils/_metadata_requests.py