CLN Removes duplicated or unneeded code in ColumnTransformer (#19261) · scikit-learn/scikit-learn@8965abb · GitHub
[go: up one dir, main page]

Skip to content

Commit 8965abb

Browse files
authored
CLN Removes duplicated or unneeded code in ColumnTransformer (#19261)
1 parent 5ef9fa4 commit 8965abb

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

sklearn/compose/_column_transformer.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from ..utils import Bunch
2020
from ..utils import _safe_indexing
2121
from ..utils import _get_column_indices
22-
from ..utils import _determine_key_type
2322
from ..utils.metaestimators import _BaseComposition
2423
from ..utils.validation import check_array, check_is_fitted
2524
from ..utils.validation import _deprecate_positional_args
@@ -320,12 +319,6 @@ def _validate_remainder(self, X):
320319
"'passthrough', or estimator. '%s' was passed instead" %
321320
self.remainder)
322321

323-
# Make it possible to check for reordered named columns on transform
324-
self._has_str_cols = any(_determine_key_type(cols) == 'str'
325-
for cols in self._columns)
326-
if hasattr(X, 'columns'):
327-
self._df_columns = X.columns
328-
329322
self._n_features = X.shape[1]
330323
cols = []
331324
for columns in self._columns:
@@ -362,12 +355,12 @@ def get_feature_names(self):
362355
hasattr(column, '__len__') and not len(column)):
363356
continue
364357
if trans == 'passthrough':
365-
if hasattr(self, '_df_columns'):
358+
if self._feature_names_in is not None:
366359
if ((not isinstance(column, slice))
367360
and all(isinstance(col, str) for col in column)):
368361
feature_names.extend(column)
369362
else:
370-
feature_names.extend(self._df_columns[column])
363+
feature_names.extend(self._feature_names_in[column])
371364
else:
372365
indices = np.arange(self._n_features)
373366
feature_names.extend(['x%d' % i for i in indices[column]])
@@ -441,7 +434,7 @@ def _fit_transform(self, X, y, func, fitted=False):
441434
message_clsname='ColumnTransformer',
442435
message=self._log_message(name, idx, len(transformers)))
443436
for idx, (name, trans, column, weight) in enumerate(
444-
self._iter(fitted=fitted, replace_strings=True), 1))
437+
transformers, 1))
445438
except ValueError as e:
446439
if "Expected 2D array, got 1D array instead" in str(e):
447440
raise ValueError(_ERR_MSG_1DCOLUMN) from e
@@ -606,9 +599,9 @@ def _sk_visual_block_(self):
606599
transformers = self.transformers
607600
elif hasattr(self, "_remainder"):
608601
remainder_columns = self._remainder[2]
609-
if hasattr(self, '_df_columns'):
602+
if self._feature_names_in is not None:
610603
remainder_columns = (
611-
self._df_columns[remainder_columns].tolist()
604+
self._feature_names_in[remainder_columns].tolist()
612605
)
613606
transformers = chain(self.transformers,
614607
[('remainder', self.remainder,

0 commit comments

Comments
 (0)
0