[MRG +1] ColumnTransformer: store evaluated function column specifier… · scikit-learn/scikit-learn@1dc0205 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1dc0205

Browse files
jorisvandenbossche authored and jnothman committed
[MRG +1] ColumnTransformer: store evaluated function column specifier during fit (#12107)
1 parent e6f7b91 commit 1dc0205

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

sklearn/compose/_column_transformer.py

Lines changed: 39 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -211,20 +211,29 @@ def set_params(self, **kwargs):
211211
self._set_params('_transformers', **kwargs)
212212
return self
213213

214-
def _iter(self, X=None, fitted=False, replace_strings=False):
215-
"""Generate (name, trans, column, weight) tuples
214+
def _iter(self, fitted=False, replace_strings=False):
215+
"""
216+
Generate (name, trans, X_subset, weight, column) tuples.
217+
218+
If fitted=True, use the fitted transformers, else use the
219+
user specified transformers updated with converted column names
220+
and potentially appended with transformer for remainder.
221+
216222
"""
217223
if fitted:
218224
transformers = self.transformers_
219225
else:
220-
transformers = self.transformers
226+
# interleave the validated column specifiers
227+
transformers = [
228+
(name, trans, column) for (name, trans, _), column
229+
in zip(self.transformers, self._columns)
230+
]
231+
# add transformer tuple for remainder
221232
if self._remainder[2] is not None:
222233
transformers = chain(transformers, [self._remainder])
223234
get_weight = (self.transformer_weights or {}).get
224235

225236
for name, trans, column in transformers:
226-
sub = None if X is None else _get_column(X, column)
227-
228237
if replace_strings:
229238
# replace 'passthrough' with identity transformer and
230239
# skip in case of 'drop'
@@ -235,7 +244,7 @@ def _iter(self, X=None, fitted=False, replace_strings=False):
235244
elif trans == 'drop':
236245
continue
237246

238-
yield (name, trans, sub, get_weight(name))
247+
yield (name, trans, column, get_weight(name))
239248

240249
def _validate_transformers(self):
241250
if not self.transformers:
@@ -257,6 +266,17 @@ def _validate_transformers(self):
257266
"specifiers. '%s' (type %s) doesn't." %
258267
(t, type(t)))
259268

269+
def _validate_column_callables(self, X):
270+
"""
271+
Converts callable column specifications.
272+
"""
273+
columns = []
274+
for _, _, column in self.transformers:
275+
if callable(column):
276+
column = column(X)
277+
columns.append(column)
278+
self._columns = columns
279+
260280
def _validate_remainder(self, X):
261281
"""
262282
Validates ``remainder`` and defines ``_remainder`` targeting
@@ -274,7 +294,7 @@ def _validate_remainder(self, X):
274294

275295
n_columns = X.shape[1]
276296
cols = []
277-
for _, _, columns in self.transformers:
297+
for columns in self._columns:
278298
cols.extend(_get_column_indices(X, columns))
279299
remaining_idx = sorted(list(set(range(n_columns)) - set(cols))) or None
280300

@@ -320,35 +340,32 @@ def get_feature_names(self):
320340

321341
def _update_fitted_transformers(self, transformers):
322342
# transformers are fitted; excludes 'drop' cases
323-
transformers = iter(transformers)
343+
fitted_transformers = iter(transformers)
324344
transformers_ = []
325345

326-
transformer_iter = self.transformers
327-
if self._remainder[2] is not None:
328-
transformer_iter = chain(transformer_iter, [self._remainder])
329-
330-
for name, old, column in transformer_iter:
346+
for name, old, column, _ in self._iter():
331347
if old == 'drop':
332348
trans = 'drop'
333349
elif old == 'passthrough':
334350
# FunctionTransformer is present in list of transformers,
335351
# so get next transformer, but save original string
336-
next(transformers)
352+
next(fitted_transformers)
337353
trans = 'passthrough'
338354
else:
339-
trans = next(transformers)
355+
trans = next(fitted_transformers)
340356
transformers_.append((name, trans, column))
341357

342358
# sanity check that transformers is exhausted
343-
assert not list(transformers)
359+
assert not list(fitted_transformers)
344360
self.transformers_ = transformers_
345361

346362
def _validate_output(self, result):
347363
"""
348364
Ensure that the output of each transformer is 2D. Otherwise
349365
hstack can raise an error or produce incorrect results.
350366
"""
351-
names = [name for name, _, _, _ in self._iter(replace_strings=True)]
367+
names = [name for name, _, _, _ in self._iter(fitted=True,
368+
replace_strings=True)]
352369
for Xs, name in zip(result, names):
353370
if not getattr(Xs, 'ndim', 0) == 2:
354371
raise ValueError(
@@ -366,9 +383,9 @@ def _fit_transform(self, X, y, func, fitted=False):
366383
try:
367384
return Parallel(n_jobs=self.n_jobs)(
368385
delayed(func)(clone(trans) if not fitted else trans,
369-
X_sel, y, weight)
370-
for _, trans, X_sel, weight in self._iter(
371-
X=X, fitted=fitted, replace_strings=True))
386+
_get_column(X, column), y, weight)
387+
for _, trans, column, weight in self._iter(
388+
fitted=fitted, replace_strings=True))
372389
except ValueError as e:
373390
if "Expected 2D array, got 1D array instead" in str(e):
374391
raise ValueError(_ERR_MSG_1DCOLUMN)
@@ -419,8 +436,9 @@ def fit_transform(self, X, y=None):
419436
sparse matrices.
420437
421438
"""
422-
self._validate_remainder(X)
423439
self._validate_transformers()
440+
self._validate_column_callables(X)
441+
self._validate_remainder(X)
424442

425443
result = self._fit_transform(X, y, _fit_transform_one)
426444

@@ -545,9 +563,6 @@ def _get_column(X, key):
545563
can use any hashable object as key).
546564
547565
"""
548-
if callable(key):
549-
key = key(X)
550-
551566
# check whether we have string column names or integers
552567
if _check_key_type(key, int):
553568
column_names = False
@@ -589,9 +604,6 @@ def _get_column_indices(X, key):
589604
"""
590605
n_columns = X.shape[1]
591606

592-
if callable(key):
593-
key = key(X)
594-
595607
if _check_key_type(key, int):
596608
if isinstance(key, int):
597609
return [key]

sklearn/compose/tests/test_column_transformer.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -873,6 +873,8 @@ def func(X):
873873
remainder='drop')
874874
assert_array_equal(ct.fit_transform(X_array), X_res_first)
875875
assert_array_equal(ct.fit(X_array).transform(X_array), X_res_first)
876+
assert callable(ct.transformers[0][2])
877+
assert ct.transformers_[0][2] == [0]
876878

877879
pd = pytest.importorskip('pandas')
878880
X_df = pd.DataFrame(X_array, columns=['first', 'second'])
@@ -886,3 +888,5 @@ def func(X):
886888
remainder='drop')
887889
assert_array_equal(ct.fit_transform(X_df), X_res_first)
888890
assert_array_equal(ct.fit(X_df).transform(X_df), X_res_first)
891+
assert callable(ct.transformers[0][2])
892+
assert ct.transformers_[0][2] == ['first']

0 commit comments

Comments
 (0)
0