10000 ENH ColumnTransformer.get_feature_names() handles passthrough (#14048) · scikit-learn/scikit-learn@670b85c · GitHub
[go: up one dir, main page]

Skip to content

Commit 670b85c

Browse files
authored
ENH ColumnTransformer.get_feature_names() handles passthrough (#14048)
1 parent a0e6b95 commit 670b85c

File tree

3 files changed

+97
-22
lines changed

3 files changed

+97
-22
lines changed

doc/whats_new/v0.23.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ Changelog
105105
a column name that is not unique in the dataframe. :pr:`16431` by
106106
`Thomas Fan`_.
107107

108+
- |Enhancement| :class:`compose.ColumnTransformer` method ``get_feature_names``
109+
now supports `'passthrough'` columns, with the feature name being either
110+
the column name for a dataframe, or `'xi'` for column index `i`.
111+
:pr:`14048` by :user:`Lewis Ball <lrjball>`.
112+
108113
:mod:`sklearn.datasets`
109114
.......................
110115

sklearn/compose/_column_transformer.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -315,19 +315,18 @@ def _validate_remainder(self, X):
315315
self.remainder)
316316

317317
# Make it possible to check for reordered named columns on transform
318-
if (hasattr(X, 'columns') and
319-
any(_determine_key_type(cols) == 'str'
320-
for cols in self._columns)):
318+
self._has_str_cols = any(_determine_key_type(cols) == 'str'
319+
for cols in self._columns)
320+
if hasattr(X, 'columns'):
321321
self._df_columns = X.columns
322322

323323
self._n_features = X.shape[1]
324324
cols = []
325325
for columns in self._columns:
326326
cols.extend(_get_column_indices(X, columns))
327-
remaining_idx = list(set(range(self._n_features)) - set(cols))
328-
remaining_idx = sorted(remaining_idx) or None
329327

330-
self._remainder = ('remainder', self.remainder, remaining_idx)
328+
remaining_idx = sorted(set(range(self._n_features)) - set(cols))
329+
self._remainder = ('remainder', self.remainder, remaining_idx or None)
331330

332331
@property
333332
def named_transformers_(self):
@@ -356,11 +355,18 @@ def get_feature_names(self):
356355
if trans == 'drop' or (
357356
hasattr(column, '__len__') and not len(column)):
358357
continue
359-
elif trans == 'passthrough':
360-
raise NotImplementedError(
361-
"get_feature_names is not yet supported when using "
362-
"a 'passthrough' transformer.")
363-
elif not hasattr(trans, 'get_feature_names'):
358+
if trans == 'passthrough':
359+
if hasattr(self, '_df_columns'):
360+
if ((not isinstance(column, slice))
361+
and all(isinstance(col, str) for col in column)):
362+
feature_names.extend(column)
363+
else:
364+
feature_names.extend(self._df_columns[column])
365+
else:
366+
indices = np.arange(self._n_features)
367+
feature_names.extend(['x%d' % i for i in indices[column]])
368+
continue
369+
if not hasattr(trans, 'get_feature_names'):
364370
raise AttributeError("Transformer %s (type %s) does not "
365371
"provide get_feature_names."
366372
% (str(name), type(trans).__name__))
@@ -582,6 +588,7 @@ def transform(self, X):
582588
# name order and count. See #14237 for details.
583589
if (self._remainder[2] is not None and
584590
hasattr(self, '_df_columns') and
591+
self._has_str_cols and
585592
hasattr(X, 'columns')):
586593
n_cols_fit = len(self._df_columns)
587594
n_cols_transform = len(X.columns)

sklearn/compose/tests/test_column_transformer.py

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -668,25 +668,88 @@ def test_column_transformer_get_feature_names():
668668
ct.fit(X)
669669
assert ct.get_feature_names() == ['col0__a', 'col0__b', 'col1__c']
670670

671-
# passthrough transformers not supported
671+
# drop transformer
672+
ct = ColumnTransformer(
673+
[('col0', DictVectorizer(), 0), ('col1', 'drop', 1)])
674+
ct.fit(X)
675+
assert ct.get_feature_names() == ['col0__a', 'col0__b']
676+
677+
# passthrough transformer
672678
ct = ColumnTransformer([('trans', 'passthrough', [0, 1])])
673679
ct.fit(X)
674-
assert_raise_message(
675-
NotImplementedError, 'get_feature_names is not yet supported',
676-
ct.get_feature_names)
680+
assert ct.get_feature_names() == ['x0', 'x1']
677681

678682
ct = ColumnTransformer([('trans', DictVectorizer(), 0)],
679683
remainder='passthrough')
680684
ct.fit(X)
681-
assert_raise_message(
682-
NotImplementedError, 'get_feature_names is not yet supported',
683-
ct.get_feature_names)
685+
assert ct.get_feature_names() == ['trans__a', 'trans__b', 'x1']
684686

685-
# drop transformer
686-
ct = ColumnTransformer(
687-
[('col0', DictVectorizer(), 0), ('col1', 'drop', 1)])
687+
ct = ColumnTransformer([('trans', 'passthrough', [1])],
688+
remainder='passthrough')
688689
ct.fit(X)
689-
assert ct.get_feature_names() == ['col0__a', 'col0__b']
690+
assert ct.get_feature_names() == ['x1', 'x0']
691+
692+
ct = ColumnTransformer([('trans', 'passthrough', lambda x: [1])],
693+
remainder='passthrough')
694+
ct.fit(X)
695+
assert ct.get_feature_names() == ['x1', 'x0']
696+
697+
ct = ColumnTransformer([('trans', 'passthrough', np.array([False, True]))],
698+
remainder='passthrough')
699+
ct.fit(X)
700+
assert ct.get_feature_names() == ['x1', 'x0']
701+
702+
ct = ColumnTransformer([('trans', 'passthrough', slice(1, 2))],
703+
remainder='passthrough')
704+
ct.fit(X)
705+
assert ct.get_feature_names() == ['x1', 'x0']
706+
707+
708+
def test_column_transformer_get_feature_names_dataframe():
709+
# passthough transformer with a dataframe
710+
pd = pytest.importorskip('pandas')
711+
X = np.array([[{'a': 1, 'b': 2}, {'a': 3, 'b': 4}],
712+
[{'c': 5}, {'c': 6}]], dtype=object).T
713+
X_df = pd.DataFrame(X, columns=['col0', 'col1'])
714+
715+
ct = ColumnTransformer([('trans', 'passthrough', ['col0', 'col1'])])
716+
ct.fit(X_df)
717+
assert ct.get_feature_names() == ['col0', 'col1']
718+
719+
ct = ColumnTransformer([('trans', 'passthrough', [0, 1])])
720+
ct.fit(X_df)
721+
assert ct.get_feature_names() == ['col0', 'col1']
722+
723+
ct = ColumnTransformer([('col0', DictVectorizer(), 0)],
724+
remainder='passthrough')
725+
ct.fit(X_df)
726+
assert ct.get_feature_names() == ['col0__a', 'col0__b', 'col1']
727+
728+
ct = ColumnTransformer([('trans', 'passthrough', ['col1'])],
729+
remainder='passthrough')
730+
ct.fit(X_df)
731+
assert ct.get_feature_names() == ['col1', 'col0']
732+
733+
ct = ColumnTransformer([('trans', 'passthrough',
734+
lambda x: x[['col1']].columns)],
735+
remainder='passthrough')
736+
ct.fit(X_df)
737+
assert ct.get_feature_names() == ['col1', 'col0']
738+
739+
ct = ColumnTransformer([('trans', 'passthrough', np.array([False, True]))],
740+
remainder='passthrough')
741+
ct.fit(X_df)
742+
assert ct.get_feature_names() == ['col1', 'col0']
743+
744+
ct = ColumnTransformer([('trans', 'passthrough', slice(1, 2))],
745+
remainder='passthrough')
746+
ct.fit(X_df)
747+
assert ct.get_feature_names() == ['col1', 'col0']
748+
749+
ct = ColumnTransformer([('trans', 'passthrough', [1])],
750+
remainder='passthrough')
751+
ct.fit(X_df)
752+
assert ct.get_feature_names() == ['col1', 'col0']
690753

691754

692755
def test_column_transformer_special_strings():

0 commit comments

Comments
 (0)
0