FIX (SLEP6) make Pipeline work with an estimator implementing __len__… · scikit-learn/scikit-learn@594475a · GitHub
[go: up one dir, main page]

Skip to content

Commit 594475a

Browse files
authored
FIX (SLEP6) make Pipeline work with an estimator implementing __len__ (#26964)
1 parent 9c96671 commit 594475a

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

sklearn/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,7 +1106,7 @@ def get_metadata_routing(self):
11061106
router = MetadataRouter(owner=self.__class__.__name__)
11071107

11081108
# first we add all steps except the last one
1109-
for _, name, trans in self._iter(with_final=False):
1109+
for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
11101110
method_mapping = MethodMapping()
11111111
# fit, fit_predict, and fit_transform call fit_transform if it
11121112
# exists, or else fit and transform
@@ -1140,7 +1140,7 @@ def get_metadata_routing(self):
11401140
router.add(method_mapping=method_mapping, **{name: trans})
11411141

11421142
final_name, final_est = self.steps[-1]
1143-
if not final_est:
1143+
if final_est is None or final_est == "passthrough":
11441144
return router
11451145

11461146
# then we add the last step

sklearn/tests/test_pipeline.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from sklearn.datasets import load_iris
1818
from sklearn.decomposition import PCA, TruncatedSVD
1919
from sklearn.dummy import DummyRegressor
20-
from sklearn.ensemble import HistGradientBoostingClassifier
20+
from sklearn.ensemble import (
21+
HistGradientBoostingClassifier,
22+
RandomForestClassifier,
23+
RandomTreesEmbedding,
24+
)
2125
from sklearn.exceptions import NotFittedError
2226
from sklearn.feature_extraction.text import CountVectorizer
2327
from sklearn.feature_selection import SelectKBest, f_classif
@@ -27,7 +31,7 @@
2731
from sklearn.model_selection import train_test_split
2832
from sklearn.neighbors import LocalOutlierFactor
2933
from sklearn.pipeline import FeatureUnion, Pipeline, make_pipeline, make_union
30-
from sklearn.preprocessing import StandardScaler
34+
from sklearn.preprocessing import FunctionTransformer, StandardScaler
3135
from sklearn.svm import SVC
3236
from sklearn.utils._metadata_requests import COMPOSITE_METHODS, METHODS
3337
from sklearn.utils._testing import (
@@ -1828,5 +1832,26 @@ def test_routing_passed_metadata_not_supported(method):
18281832
getattr(pipe, method)([[1]], sample_weight=[1], prop="a")
18291833

18301834

1835+
@pytest.mark.usefixtures("enable_slep006")
1836+
def test_pipeline_with_estimator_with_len():
1837+
"""Test that pipeline works with estimators that have a `__len__` method."""
1838+
pipe = Pipeline(
1839+
[("trs", RandomTreesEmbedding()), ("estimator", RandomForestClassifier())]
1840+
)
1841+
pipe.fit([[1]], [1])
1842+
pipe.predict([[1]])
1843+
1844+
1845+
@pytest.mark.usefixtures("enable_slep006")
1846+
@pytest.mark.parametrize("last_step", [None, "passthrough"])
1847+
def test_pipeline_with_no_last_step(last_step):
1848+
"""Test that the pipeline works when there is no last step.
1849+
1850+
It should just ignore and pass through the data on transform.
1851+
"""
1852+
pipe = Pipeline([("trs", FunctionTransformer()), ("estimator", last_step)])
1853+
assert pipe.fit([[1]], [1]).transform([[1], [2], [3]]) == [[1], [2], [3]]
1854+
1855+
18311856
# End of routing tests
18321857
# ====================

0 commit comments

Comments
 (0)
0