FIX make pipeline pass check_estimator (#26325) · scikit-learn/scikit-learn@8521819 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8521819

Browse files
authored
FIX make pipeline pass check_estimator (#26325)
1 parent faf9b1f commit 8521819

File tree

4 files changed

+49
-14
lines changed

4 files changed

+49
-14
lines changed

doc/whats_new/v1.3.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,10 @@ Changelog
467467
`pandas.DataFrame`
468468
:pr:`25220` by :user:`Ian Thompson <it176131>`.
469469

470+
- |Fix| :meth:`pipeline.Pipeline.fit_transform` now raises an `AttributeError`
471+
if the last step of the pipeline does not support `fit_transform`.
472+
:pr:`26325` by `Adrin Jalali`_.
473+
470474
:mod:`sklearn.preprocessing`
471475
............................
472476

sklearn/pipeline.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,14 @@ def fit(self, X, y=None, **fit_params):
421421

422422
return self
423423

424+
def _can_fit_transform(self):
425+
return (
426+
self._final_estimator == "passthrough"
427+
or hasattr(self._final_estimator, "transform")
428+
or hasattr(self._final_estimator, "fit_transform")
429+
)
430+
431+
@available_if(_can_fit_transform)
424432
def fit_transform(self, X, y=None, **fit_params):
425433
"""Fit the model and transform with the final estimator.
426434
@@ -744,12 +752,34 @@ def classes_(self):
744752
return self.steps[-1][1].classes_
745753

746754
def _more_tags(self):
755+
tags = {
756+
"_xfail_checks": {
757+
"check_dont_overwrite_parameters": (
758+
"Pipeline changes the `steps` parameter, which it shouldn't."
759+
"Therefore this test is x-fail until we fix this."
760+
),
761+
"check_estimators_overwrite_params": (
762+
"Pipeline changes the `steps` parameter, which it shouldn't."
763+
"Therefore this test is x-fail until we fix this."
764+
),
765+
}
766+
}
767+
747768
try:
748-
return {"pairwise": _safe_tags(self.steps[0][1], "pairwise")}
769+
tags["pairwise"] = _safe_tags(self.steps[0][1], "pairwise")
749770
except (ValueError, AttributeError, TypeError):
750771
# This happens when the `steps` is not a list of (name, estimator)
751772
# tuples and `fit` is not called yet to validate the steps.
752-
return {}
773+
pass
774+
775+
try:
776+
tags["multioutput"] = _safe_tags(self.steps[-1][1], "multioutput")
777+
except (ValueError, AttributeError, TypeError):
778+
# This happens when the `steps` is not a list of (name, estimator)
779+
# tuples and `fit` is not called yet to validate the steps.
780+
pass
781+
782+
return tags
753783

754784
def get_feature_names_out(self, input_features=None):
755785
"""Get output feature names for transformation.

sklearn/tests/test_common.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,17 @@ def _tested_estimators(type_filter=None):
133133
yield estimator
134134

135135

136-
@parametrize_with_checks(list(_tested_estimators()))
136+
def _generate_pipeline():
137+
for final_estimator in [Ridge(), LogisticRegression()]:
138+
yield Pipeline(
139+
steps=[
140+
("scaler", StandardScaler()),
141+
("final_estimator", final_estimator),
142+
]
143+
)
144+
145+
146+
@parametrize_with_checks(list(chain(_tested_estimators(), _generate_pipeline())))
137147
def test_estimators(estimator, check, request):
138148
# Common tests for estimator instances
139149
with ignore_warnings(category=(FutureWarning, ConvergenceWarning, UserWarning)):
@@ -283,16 +293,6 @@ def _generate_column_transformer_instances():
283293
)
284294

285295

286-
def _generate_pipeline():
287-
for final_estimator in [Ridge(), LogisticRegression()]:
288-
yield Pipeline(
289-
steps=[
290-
("scaler", StandardScaler()),
291-
("final_estimator", final_estimator),
292-
]
293-
)
294-
295-
296296
def _generate_search_cv_instances():
297297
for SearchCV, (Estimator, param_grid) in product(
298298
[

sklearn/tests/test_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,8 @@ def test_set_pipeline_steps():
670670
with pytest.raises(TypeError, match=msg):
671671
pipeline.fit([[1]], [1])
672672

673-
with pytest.raises(TypeError, match=msg):
673+
msg = "This 'Pipeline' has no attribute 'fit_transform'"
674+
with pytest.raises(AttributeError, match=msg):
674675
pipeline.fit_transform([[1]], [1])
675676

676677

0 commit comments

Comments
 (0)
0