[go: up one dir, main page]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT SLEP6: raise NotImplementedError for meta-estimators not supporting metadata routing #27389

Merged
merged 18 commits into from
Sep 25, 2023
Merged
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
FIX pipeline shouldn't call process_routing unless routing is enabled
  • Loading branch information
adrinjalali committed Sep 17, 2023
commit 86d5f181be1778d7cb87e7a2f6feca51f81e18a4
27 changes: 18 additions & 9 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,12 +750,15 @@ def decision_function(self, X, **params):
Result of calling `decision_function` on the final estimator.
"""
_raise_for_params(params, self, "decision_function")
Xt = X

if not _routing_enabled():
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(Xt)
return self.steps[-1][1].decision_function(Xt)

# not branching here since params is only available if
# enable_metadata_routing=True
routed_params = process_routing(self, "decision_function", **params)

Xt = X
for _, name, transform in self._iter(with_final=False):
Xt = transform.transform(
Xt, **routed_params.get(name, {}).get("transform", {})
Expand Down Expand Up @@ -886,10 +889,13 @@ def transform(self, X, **params):
"""
_raise_for_params(params, self, "transform")

# not branching here since params is only available if
# enable_metadata_routing=True
routed_params = process_routing(self, "transform", **params)
Xt = X
if not _routing_enabled():
for _, name, transform in self._iter():
Xt = transform.transform(Xt)
return Xt

routed_params = process_routing(self, "transform", **params)
for _, name, transform in self._iter():
Xt = transform.transform(Xt, **routed_params[name].transform)
return Xt
Expand Down Expand Up @@ -928,11 +934,14 @@ def inverse_transform(self, Xt, **params):
space.
"""
_raise_for_params(params, self, "inverse_transform")
reverse_iter = reversed(list(self._iter()))

if not _routing_enabled():
for _, name, transform in reverse_iter:
Xt = transform.inverse_transform(Xt)
return Xt

# we don't have to branch here, since params is only non-empty if
# enable_metadata_routing=True.
routed_params = process_routing(self, "inverse_transform", **params)
reverse_iter = reversed(list(self._iter()))
for _, name, transform in reverse_iter:
Xt = transform.inverse_transform(
Xt, **routed_params[name].inverse_transform
Expand Down