8000 MNT remove coupling between pipeline methods (#16777) · scikit-learn/scikit-learn@88ce8cd · GitHub
[go: up one dir, main page]

Skip to content

Commit 88ce8cd

Browse files
authored
MNT remove coupling between pipeline methods (#16777)
1 parent 0a866ec commit 88ce8cd

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

sklearn/pipeline.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -258,17 +258,7 @@ def _log_message(self, step_idx):
258258
len(self.steps),
259259
name)
260260

261-
# Estimator interface
262-
263-
def _fit(self, X, y=None, **fit_params):
264-
# shallow copy of steps - this should really be steps_
265-
self.steps = list(self.steps)
266-
self._validate_steps()
267-
# Setup the memory
268-
memory = check_memory(self.memory)
269-
270-
fit_transform_one_cached = memory.cache(_fit_transform_one)
271-
261+
def _check_fit_params(self, **fit_params):
272262
fit_params_steps = {name: {} for name, step in self.steps
273263
if step is not None}
274264
for pname, pval in fit_params.items():
@@ -281,6 +271,19 @@ def _fit(self, X, y=None, **fit_params):
281271
"=sample_weight)`.".format(pname))
282272
step, param = pname.split('__', 1)
283273
fit_params_steps[step][param] = pval
274+
return fit_params_steps
275+
276+
# Estimator interface
277+
278+
def _fit(self, X, y=None, **fit_params_steps):
279+
# shallow copy of steps - this should really be steps_
280+
self.steps = list(self.steps)
281+
self._validate_steps()
282+
# Setup the memory
283+
memory = check_memory(self.memory)
284+
285+
fit_transform_one_cached = memory.cache(_fit_transform_one)
286+
284287
for (step_idx,
285288
name,
286289
transformer) in self._iter(with_final=False,
@@ -318,9 +321,7 @@ def _fit(self, X, y=None, **fit_params):
318321
# transformer. This is necessary when loading the transformer
319322
# from the cache.
320323
self.steps[step_idx] = (name, fitted_transformer)
321-
if self._final_estimator == 'passthrough':
322-
return X, {}
323-
return X, fit_params_steps[self.steps[-1][0]]
324+
return X
324325

325326
def fit(self, X, y=None, **fit_params):
326327
"""Fit the model
@@ -348,11 +349,14 @@ def fit(self, X, y=None, **fit_params):
348349
self : Pipeline
349350
This estimator
350351
"""
351-
Xt, fit_params = self._fit(X, y, **fit_params)
352+
fit_params_steps = self._check_fit_params(**fit_params)
353+
Xt = self._fit(X, y, **fit_params_steps)
352354
with _print_elapsed_time('Pipeline',
353355
self._log_message(len(self.steps) - 1)):
354356
if self._final_estimator != 'passthrough':
355-
self._final_estimator.fit(Xt, y, **fit_params)
357+
fit_params_last_step = fit_params_steps[self.steps[-1][0]]
358+
self._final_estimator.fit(Xt, y, **fit_params_last_step)
359+
356360
return self
357361

358362
def fit_transform(self, X, y=None, **fit_params):
@@ -382,16 +386,20 @@ def fit_transform(self, X, y=None, **fit_params):
382386
Xt : array-like of shape (n_samples, n_transformed_features)
383387
Transformed samples
384388
"""
389+
fit_params_steps = self._check_fit_params(**fit_params)
390+
Xt = self._fit(X, y, **fit_params_steps)
391+
385392
last_step = self._final_estimator
386-
Xt, fit_params = self._fit(X, y, **fit_params)
387393
with _print_elapsed_time('Pipeline',
388394
self._log_message(len(self.steps) - 1)):
389395
if last_step == 'passthrough':
390396
return Xt
397+
fit_params_last_step = fit_params_steps[self.steps[-1][0]]
391398
if hasattr(last_step, 'fit_transform'):
392-
return last_step.fit_transform(Xt, y, **fit_params)
399+
return last_step.fit_transform(Xt, y, **fit_params_last_step)
393400
else:
394-
return last_step.fit(Xt, y, **fit_params).transform(Xt)
401+
return last_step.fit(Xt, y,
402+
**fit_params_last_step).transform(Xt)
395403

396404
@if_delegate_has_method(delegate='_final_estimator')
397405
def predict(self, X, **predict_params):
@@ -447,10 +455,14 @@ def fit_predict(self, X, y=None, **fit_params):
447455
-------
448456
y_pred : array-like
449457
"""
450-
Xt, fit_params = self._fit(X, y, **fit_params)
458+
fit_params_steps = self._check_fit_params(**fit_params)
459+
Xt = self._fit(X, y, **fit_params_steps)
460+
461+
fit_params_last_step = fit_params_steps[self.steps[-1][0]]
451462
with _print_elapsed_time('Pipeline',
452463
self._log_message(len(self.steps) - 1)):
453-
y_pred = self.steps[-1][-1].fit_predict(Xt, y, **fit_params)
464+
y_pred = self.steps[-1][-1].fit_predict(Xt, y,
465+
**fit_params_last_step)
454466
return y_pred
455467

456468
@if_delegate_has_method(delegate='_final_estimator')

0 commit comments

Comments
 (0)
0