FIX Allows pipeline to pass through feature names (#21351) · scikit-learn/scikit-learn@a3f09ea · GitHub
[go: up one dir, main page]

Skip to content

Commit a3f09ea

Browse files
authored
FIX Allows pipeline to pass through feature names (#21351)
1 parent a343963 commit a3f09ea

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

doc/whats_new/v1.0.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ Fixed models
6060
and `radius_neighbors`, due to handling of explicit zeros in `bsr` and `dok`
6161
:term:`sparse graph` formats. :pr:`21199` by `Thomas Fan`_.
6262

63+
:mod:`sklearn.pipeline`
64+
.......................
65+
66+
- |Fix| :meth:`pipeline.Pipeline.get_feature_names_out` correctly passes feature
67+
names out from one step of a pipeline to the next. :pr:`21351` by
68+
`Thomas Fan`_.
69+
6370
.. _changes_1_0:
6471

6572
Version 1.0.0

sklearn/pipeline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,15 +746,16 @@ def get_feature_names_out(self, input_features=None):
746746
feature_names_out : ndarray of str objects
747747
Transformed feature names.
748748
"""
749+
feature_names_out = input_features
749750
for _, name, transform in self._iter():
750751
if not hasattr(transform, "get_feature_names_out"):
751752
raise AttributeError(
752753
"Estimator {} does not provide get_feature_names_out. "
753754
"Did you mean to call pipeline[:-1].get_feature_names_out"
754755
"()?".format(name)
755756
)
756-
feature_names = transform.get_feature_names_out(input_features)
757-
return feature_names
757+
feature_names_out = transform.get_feature_names_out(feature_names_out)
758+
return feature_names_out
758759

759760
@property
760761
def n_features_in_(self):

sklearn/tests/test_pipeline.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,3 +1517,24 @@ def fit(self, X, y):
15171517
check_is_fitted(pipeline)
15181518
pipeline.fit(iris.data, iris.target)
15191519
check_is_fitted(pipeline)
1520+
1521+
1522+
def test_pipeline_get_feature_names_out_passes_names_through():
1523+
"""Check that pipeline passes names through.
1524+
1525+
Non-regression test for #21349.
1526+
"""
1527+
X, y = iris.data, iris.target
1528+
1529+
class AddPrefixStandardScalar(StandardScaler):
1530+
def get_feature_names_out(self, input_features=None):
1531+
names = super().get_feature_names_out(input_features=input_features)
1532+
return np.asarray([f"my_prefix_{name}" for name in names], dtype=object)
1533+
1534+
pipe = make_pipeline(AddPrefixStandardScalar(), StandardScaler())
1535+
pipe.fit(X, y)
1536+
1537+
input_names = iris.feature_names
1538+
feature_names_out = pipe.get_feature_names_out(input_names)
1539+
1540+
assert_array_equal(feature_names_out, [f"my_prefix_{name}" for name in input_names])

0 commit comments

Comments
 (0)
0