@@ -211,20 +211,29 @@ def set_params(self, **kwargs):
211
211
self ._set_params ('_transformers' , ** kwargs )
212
212
return self
213
213
214
- def _iter (self , X = None , fitted = False , replace_strings = False ):
215
- """Generate (name, trans, column, weight) tuples
214
+ def _iter (self , fitted = False , replace_strings = False ):
215
+ """
216
+ Generate (name, trans, X_subset, weight, column) tuples.
217
+
218
+ If fitted=True, use the fitted transformers, else use the
219
+ user specified transformers updated with converted column names
220
+ and potentially appended with transformer for remainder.
221
+
216
222
"""
217
223
if fitted :
218
224
transformers = self .transformers_
219
225
else :
220
- transformers = self .transformers
226
+ # interleave the validated column specifiers
227
+ transformers = [
228
+ (name , trans , column ) for (name , trans , _ ), column
229
+ in zip (self .transformers , self ._columns )
230
+ ]
231
+ # add transformer tuple for remainder
221
232
if self ._remainder [2 ] is not None :
222
233
transformers = chain (transformers , [self ._remainder ])
223
234
get_weight = (self .transformer_weights or {}).get
224
235
225
236
for name , trans , column in transformers :
226
- sub = None if X is None else _get_column (X , column )
227
-
228
237
if replace_strings :
229
238
# replace 'passthrough' with identity transformer and
230
239
# skip in case of 'drop'
@@ -235,7 +244,7 @@ def _iter(self, X=None, fitted=False, replace_strings=False):
235
244
elif trans == 'drop'
10000
span>:
236
245
continue
237
246
238
- yield (name , trans , sub , get_weight (name ))
247
+ yield (name , trans , column , get_weight (name ))
239
248
240
249
def _validate_transformers (self ):
241
250
if not self .transformers :
@@ -257,6 +266,17 @@ def _validate_transformers(self):
257
266
"specifiers. '%s' (type %s) doesn't." %
258
267
(t , type (t )))
259
268
269
+ def _validate_column_callables (self , X ):
270
+ """
271
+ Converts callable column specifications.
272
+ """
273
+ columns = []
274
+ for _ , _ , column in self .transformers :
275
+ if callable (column ):
276
+ column = column (X )
277
+ columns .append (column )
278
+ self ._columns = columns
279
+
260
280
def _validate_remainder (self , X ):
261
281
"""
262
282
Validates ``remainder`` and defines ``_remainder`` targeting
@@ -274,7 +294,7 @@ def _validate_remainder(self, X):
274
294
275
295
n_columns = X .shape [1 ]
276
296
cols = []
277
- for _ , _ , columns in self .transformers :
297
+ for columns in self ._columns :
278
298
cols .extend (_get_column_indices (X , columns ))
279
299
remaining_idx = sorted (list (set (range (n_columns )) - set (cols ))) or None
280
300
@@ -320,35 +340,32 @@ def get_feature_names(self):
320
340
321
341
def _update_fitted_transformers (self , transformers ):
322
342
# transformers are fitted; excludes 'drop' cases
323
- transformers = iter (transformers )
343
+ fitted_transformers = iter (transformers )
324
344
transformers_ = []
325
345
326
- transformer_iter = self .transformers
327
- if self ._remainder [2 ] is not None :
328
- transformer_iter = chain (transformer_iter , [self ._remainder ])
329
-
330
- for name , old , column in transformer_iter :
346
+ for name , old , column , _ in self ._iter ():
331
347
if old == 'drop' :
332
348
trans = 'drop'
333
349
elif old == 'passthrough' :
334
350
# FunctionTransformer is present in list of transformers,
335
351
# so get next transformer, but save original string
336
- next (transformers )
352
+ next (fitted_transformers )
337
353
trans = 'passthrough'
338
354
else :
339
- trans = next (transformers )
355
+ trans = next (fitted_transformers )
340
356
transformers_ .append ((name , trans , column ))
341
357
342
358
# sanity check that transformers is exhausted
343
- assert not list (transformers )
359
+ assert not list (fitted_transformers )
344
360
self .transformers_ = transformers_
345
361
346
362
def _validate_output (self , result ):
347
363
"""
348
364
Ensure that the output of each transformer is 2D. Otherwise
349
365
hstack can raise an error or produce incorrect results.
350
366
"""
351
- names = [name for name , _ , _ , _ in self ._iter (replace_strings = True )]
367
+ names = [name for name , _ , _ , _ in self ._iter (fitted = True ,
368
+ replace_strings = True )]
352
369
for Xs , name in zip (result , names ):
353
370
if not getattr (Xs , 'ndim' , 0 ) == 2 :
354
371
raise ValueError (
@@ -366,9 +383,9 @@ def _fit_transform(self, X, y, func, fitted=False):
366
383
try :
367
384
return Parallel (n_jobs = self .n_jobs )(
368
385
delayed (func )(clone (trans ) if not fitted else trans ,
369
- X_sel , y , weight )
370
- for _ , trans , X_sel , weight in self ._iter (
371
- X = X , fitted = fitted , replace_strings = True ))
386
+ _get_column ( X , column ) , y , weight )
387
+ for _ , trans , column , weight in self ._iter (
388
+ fitted = fitted , replace_strings = True ))
372
389
except ValueError as e :
373
390
if "Expected 2D array, got 1D array instead" in str (e ):
374
391
raise ValueError (_ERR_MSG_1DCOLUMN )
@@ -419,8 +436,9 @@ def fit_transform(self, X, y=None):
419
436
sparse matrices.
420
437
421
438
"""
422
- self ._validate_remainder (X )
423
439
self ._validate_transformers ()
440
+ self ._validate_column_callables (X )
441
+ self ._validate_remainder (X )
424
442
425
443
result = self ._fit_transform (X , y , _fit_transform_one )
426
444
@@ -545,9 +563,6 @@ def _get_column(X, key):
545
563
can use any hashable object as key).
546
564
547
565
"""
548
- if callable (key ):
549
- key = key (X )
550
-
551
566
# check whether we have string column names or integers
552
567
if _check_key_type (key , int ):
553
568
column_names = False
@@ -589,9 +604,6 @@ def _get_column_indices(X, key):
589
604
"""
590
605
n_columns = X .shape [1 ]
591
606
592
- if callable (key ):
593
- key = key (X )
594
-
595
607
if _check_key_type (key , int ):
596
608
if isinstance (key , int ):
597
609
return [key ]
0 commit comments