@@ -258,17 +258,7 @@ def _log_message(self, step_idx):
258
258
len (self .steps ),
259
259
name )
260
260
261
- # Estimator interface
262
-
263
- def _fit (self , X , y = None , ** fit_params ):
264
- # shallow copy of steps - this should really be steps_
265
- self .steps = list (self .steps )
266
- self ._validate_steps ()
267
- # Setup the memory
268
- memory = check_memory (self .memory )
269
-
270
- fit_transform_one_cached = memory .cache (_fit_transform_one )
271
-
261
+ def _check_fit_params (self , ** fit_params ):
272
262
fit_params_steps = {name : {} for name , step in self .steps
273
263
if step is not None }
274
264
for pname , pval in fit_params .items ():
@@ -281,6 +271,19 @@ def _fit(self, X, y=None, **fit_params):
281
271
"=sample_weight)`." .format (pname ))
282
272
step , param = pname .split ('__' , 1 )
283
273
fit_params_steps [step ][param ] = pval
274
+ return fit_params_steps
275
+
276
+ # Estimator interface
277
+
278
+ def _fit (self , X , y = None , ** fit_params_steps ):
279
+ # shallow copy of steps - this should really be steps_
280
+ self .steps = list (self .steps )
281
+ self ._validate_steps ()
282
+ # Setup the memory
283
+ memory = check_memory (self .memory )
284
+
285
+ fit_transform_one_cached = memory .cache (_fit_transform_one )
286
+
284
287
for (step_idx ,
285
288
name ,
286
289
transformer ) in self ._iter (with_final = False ,
@@ -318,9 +321,7 @@ def _fit(self, X, y=None, **fit_params):
318
321
# transformer. This is necessary when loading the transformer
319
322
# from the cache.
320
323
self .steps [step_idx ] = (name , fitted_transformer )
321
- if self ._final_estimator == 'passthrough' :
322
- return X , {}
323
- return X , fit_params_steps [self .steps [- 1 ][0 ]]
324
+ return X
324
325
325
326
def fit (self , X , y = None , ** fit_params ):
326
327
"""Fit the model
@@ -348,11 +349,14 @@ def fit(self, X, y=None, **fit_params):
348
349
self : Pipeline
349
350
This estimator
350
351
"""
351
- Xt , fit_params = self ._fit (X , y , ** fit_params )
352
+ fit_params_steps = self ._check_fit_params (** fit_params )
353
+ Xt = self ._fit (X , y , ** fit_params_steps )
352
354
with _print_elapsed_time ('Pipeline' ,
353
355
self ._log_message (len (self .steps ) - 1 )):
354
356
if self ._final_estimator != 'passthrough' :
355
- self ._final_estimator .fit (Xt , y , ** fit_params )
357
+ fit_params_last_step = fit_params_steps [self .steps [- 1 ][0 ]]
358
+ self ._final_estimator .fit (Xt , y , ** fit_params_last_step )
359
+
356
360
return self
357
361
358
362
def fit_transform (self , X , y = None , ** fit_params ):
@@ -382,16 +386,20 @@ def fit_transform(self, X, y=None, **fit_params):
382
386
Xt : array-like of shape (n_samples, n_transformed_features)
383
387
Transformed samples
384
388
"""
389
+ fit_params_steps = self ._check_fit_params (** fit_params )
390
+ Xt = self ._fit (X , y , ** fit_params_steps )
391
+
385
392
last_step = self ._final_estimator
386
- Xt , fit_params = self ._fit (X , y , ** fit_params )
387
393
with _print_elapsed_time ('Pipeline' ,
388
394
self ._log_message (len (self .steps ) - 1 )):
389
395
if last_step == 'passthrough' :
390
396
return Xt
397
+ fit_params_last_step = fit_params_steps [self .steps [- 1 ][0 ]]
391
398
if hasattr (last_step , 'fit_transform' ):
392
- return last_step .fit_transform (Xt , y , ** fit_params )
399
+ return last_step .fit_transform (Xt , y , ** fit_params_last_step )
393
400
else :
394
- return last_step .fit (Xt , y , ** fit_params ).transform (Xt )
401
+ return last_step .fit (Xt , y ,
402
+ ** fit_params_last_step ).transform (Xt )
395
403
396
404
@if_delegate_has_method (delegate = '_final_estimator' )
397
405
def predict (self , X , ** predict_params ):
@@ -447,10 +455,14 @@ def fit_predict(self, X, y=None, **fit_params):
447
455
-------
448
456
y_pred : array-like
449
457
"""
450
- Xt , fit_params = self ._fit (X , y , ** fit_params )
458
+ fit_params_steps = self ._check_fit_params (** fit_params )
459
+ Xt = self ._fit (X , y , ** fit_params_steps )
460
+
461
+ fit_params_last_step = fit_params_steps [self .steps [- 1 ][0 ]]
451
462
with _print_elapsed_time ('Pipeline' ,
452
463
self ._log_message (len (self .steps ) - 1 )):
453
- y_pred = self .steps [- 1 ][- 1 ].fit_predict (Xt , y , ** fit_params )
464
+ y_pred = self .steps [- 1 ][- 1 ].fit_predict (Xt , y ,
465
+ ** fit_params_last_step )
454
466
return y_pred
455
467
456
468
@if_delegate_has_method (delegate = '_final_estimator' )
0 commit comments