@@ -211,20 +211,29 @@ def set_params(self, **kwargs):
211211 self ._set_params ('_transformers' , ** kwargs )
212212 return self
213213
214- def _iter (self , X = None , fitted = False , replace_strings = False ):
215- """Generate (name, trans, column, weight) tuples
214+ def _iter (self , fitted = False , replace_strings = False ):
215+ """
216+ Generate (name, trans, column, weight) tuples.
217+
218+ If fitted=True, use the fitted transformers, else use the
219+ user specified transformers updated with converted column names
220+ and potentially appended with transformer for remainder.
221+
216222 """
217223 if fitted :
218224 transformers = self .transformers_
219225 else :
220- transformers = self .transformers
226+ # interleave the validated column specifiers
227+ transformers = [
228+ (name , trans , column ) for (name , trans , _ ), column
229+ in zip (self .transformers , self ._columns )
230+ ]
231+ # add transformer tuple for remainder
221232 if self ._remainder [2 ] is not None :
222233 transformers = chain (transformers , [self ._remainder ])
223234 get_weight = (self .transformer_weights or {}).get
224235
225236 for name , trans , column in transformers :
226- sub = None if X is None else _get_column (X , column )
227-
228237 if replace_strings :
229238 # replace 'passthrough' with identity transformer and
230239 # skip in case of 'drop'
@@ -235,7 +244,7 @@ def _iter(self, X=None, fitted=False, replace_strings=False):
235244 elif trans == 'drop' :
236245 continue
237246
238- yield (name , trans , sub , get_weight (name ))
247+ yield (name , trans , column , get_weight (name ))
239248
240249 def _validate_transformers (self ):
241250 if not self .transformers :
@@ -257,6 +266,17 @@ def _validate_transformers(self):
257266 "specifiers. '%s' (type %s) doesn't." %
258267 (t , type (t )))
259268
269+ def _validate_column_callables (self , X ):
270+ """
271+ Converts callable column specifications.
272+ """
273+ columns = []
274+ for _ , _ , column in self .transformers :
275+ if callable (column ):
276+ column = column (X )
277+ columns .append (column )
278+ self ._columns = columns
279+
260280 def _validate_remainder (self , X ):
261281 """
262282 Validates ``remainder`` and defines ``_remainder`` targeting
@@ -274,7 +294,7 @@ def _validate_remainder(self, X):
274294
275295 n_columns = X .shape [1 ]
276296 cols = []
277- for _ , _ , columns in self .transformers :
297+ for columns in self ._columns :
278298 cols .extend (_get_column_indices (X , columns ))
279299 remaining_idx = sorted (list (set (range (n_columns )) - set (cols ))) or None
280300
@@ -320,35 +340,32 @@ def get_feature_names(self):
320340
321341 def _update_fitted_transformers (self , transformers ):
322342 # transformers are fitted; excludes 'drop' cases
323- transformers = iter (transformers )
343+ fitted_transformers = iter (transformers )
324344 transformers_ = []
325345
326- transformer_iter = self .transformers
327- if self ._remainder [2 ] is not None :
328- transformer_iter = chain (transformer_iter , [self ._remainder ])
329-
330- for name , old , column in transformer_iter :
346+ for name , old , column , _ in self ._iter ():
331347 if old == 'drop' :
332348 trans = 'drop'
333349 elif old == 'passthrough' :
334350 # FunctionTransformer is present in list of transformers,
335351 # so get next transformer, but save original string
336- next (transformers )
352+ next (fitted_transformers )
337353 trans = 'passthrough'
338354 else :
339- trans = next (transformers )
355+ trans = next (fitted_transformers )
340356 transformers_ .append ((name , trans , column ))
341357
342358 # sanity check that transformers is exhausted
343- assert not list (transformers )
359+ assert not list (fitted_transformers )
344360 self .transformers_ = transformers_
345361
346362 def _validate_output (self , result ):
347363 """
348364 Ensure that the output of each transformer is 2D. Otherwise
349365 hstack can raise an error or produce incorrect results.
350366 """
351- names = [name for name , _ , _ , _ in self ._iter (replace_strings = True )]
367+ names = [name for name , _ , _ , _ in self ._iter (fitted = True ,
368+ replace_strings = True )]
352369 for Xs , name in zip (result , names ):
353370 if not getattr (Xs , 'ndim' , 0 ) == 2 :
354371 raise ValueError (
@@ -366,9 +383,9 @@ def _fit_transform(self, X, y, func, fitted=False):
366383 try :
367384 return Parallel (n_jobs = self .n_jobs )(
368385 delayed (func )(clone (trans ) if not fitted else trans ,
369- X_sel , y , weight )
370- for _ , trans , X_sel , weight in self ._iter (
371- X = X , fitted = fitted , replace_strings = True ))
386+ _get_column ( X , column ) , y , weight )
387+ for _ , trans , column , weight in self ._iter (
388+ fitted = fitted , replace_strings = True ))
372389 except ValueError as e :
373390 if "Expected 2D array, got 1D array instead" in str (e ):
374391 raise ValueError (_ERR_MSG_1DCOLUMN )
@@ -419,8 +436,9 @@ def fit_transform(self, X, y=None):
419436 sparse matrices.
420437
421438 """
422- self ._validate_remainder (X )
423439 self ._validate_transformers ()
440+ self ._validate_column_callables (X )
441+ self ._validate_remainder (X )
424442
425443 result = self ._fit_transform (X , y , _fit_transform_one )
426444
@@ -545,9 +563,6 @@ def _get_column(X, key):
545563 can use any hashable object as key).
546564
547565 """
548- if callable (key ):
549- key = key (X )
550-
551566 # check whether we have string column names or integers
552567 if _check_key_type (key , int ):
553568 column_names = False
@@ -589,9 +604,6 @@ def _get_column_indices(X, key):
589604 """
590605 n_columns = X .shape [1 ]
591606
592- if callable (key ):
593- key = key (X )
594-
595607 if _check_key_type (key , int ):
596608 if isinstance (key , int ):
597609 return [key ]
0 commit comments