From 9a14d86fb33701a677d8ebef529a20b3ff35ac74 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 27 Mar 2020 08:25:57 -0400
Subject: [PATCH] remove coupling between methods

---
 sklearn/pipeline.py | 54 +++++++++++++++++++++++++++------------------
 1 file changed, 33 insertions(+), 21 deletions(-)

diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 64d2de70df531..b44ad19cc187e 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -258,17 +258,7 @@ def _log_message(self, step_idx):
                                                   len(self.steps),
                                                   name)
 
-    # Estimator interface
-
-    def _fit(self, X, y=None, **fit_params):
-        # shallow copy of steps - this should really be steps_
-        self.steps = list(self.steps)
-        self._validate_steps()
-        # Setup the memory
-        memory = check_memory(self.memory)
-
-        fit_transform_one_cached = memory.cache(_fit_transform_one)
-
+    def _check_fit_params(self, **fit_params):
         fit_params_steps = {name: {} for name, step in self.steps
                             if step is not None}
         for pname, pval in fit_params.items():
@@ -281,6 +271,19 @@ def _fit(self, X, y=None, **fit_params):
                     "=sample_weight)`.".format(pname))
             step, param = pname.split('__', 1)
             fit_params_steps[step][param] = pval
+        return fit_params_steps
+
+    # Estimator interface
+
+    def _fit(self, X, y=None, **fit_params_steps):
+        # shallow copy of steps - this should really be steps_
+        self.steps = list(self.steps)
+        self._validate_steps()
+        # Setup the memory
+        memory = check_memory(self.memory)
+
+        fit_transform_one_cached = memory.cache(_fit_transform_one)
+
         for (step_idx,
              name,
              transformer) in self._iter(with_final=False,
@@ -318,9 +321,7 @@ def _fit(self, X, y=None, **fit_params):
             # transformer. This is necessary when loading the transformer
             # from the cache.
             self.steps[step_idx] = (name, fitted_transformer)
-        if self._final_estimator == 'passthrough':
-            return X, {}
-        return X, fit_params_steps[self.steps[-1][0]]
+        return X
 
     def fit(self, X, y=None, **fit_params):
         """Fit the model
@@ -348,11 +349,14 @@ def fit(self, X, y=None, **fit_params):
         self : Pipeline
             This estimator
         """
-        Xt, fit_params = self._fit(X, y, **fit_params)
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt = self._fit(X, y, **fit_params_steps)
         with _print_elapsed_time('Pipeline',
                                  self._log_message(len(self.steps) - 1)):
             if self._final_estimator != 'passthrough':
-                self._final_estimator.fit(Xt, y, **fit_params)
+                fit_params_last_step = fit_params_steps[self.steps[-1][0]]
+                self._final_estimator.fit(Xt, y, **fit_params_last_step)
+
         return self
 
     def fit_transform(self, X, y=None, **fit_params):
@@ -382,16 +386,20 @@ def fit_transform(self, X, y=None, **fit_params):
         Xt : array-like of shape (n_samples, n_transformed_features)
             Transformed samples
         """
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt = self._fit(X, y, **fit_params_steps)
+
         last_step = self._final_estimator
-        Xt, fit_params = self._fit(X, y, **fit_params)
         with _print_elapsed_time('Pipeline',
                                  self._log_message(len(self.steps) - 1)):
             if last_step == 'passthrough':
                 return Xt
+            fit_params_last_step = fit_params_steps[self.steps[-1][0]]
             if hasattr(last_step, 'fit_transform'):
-                return last_step.fit_transform(Xt, y, **fit_params)
+                return last_step.fit_transform(Xt, y, **fit_params_last_step)
             else:
-                return last_step.fit(Xt, y, **fit_params).transform(Xt)
+                return last_step.fit(Xt, y,
+                                     **fit_params_last_step).transform(Xt)
 
     @if_delegate_has_method(delegate='_final_estimator')
     def predict(self, X, **predict_params):
@@ -447,10 +455,14 @@ def fit_predict(self, X, y=None, **fit_params):
         -------
         y_pred : array-like
         """
-        Xt, fit_params = self._fit(X, y, **fit_params)
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt = self._fit(X, y, **fit_params_steps)
+
+        fit_params_last_step = fit_params_steps[self.steps[-1][0]]
         with _print_elapsed_time('Pipeline',
                                  self._log_message(len(self.steps) - 1)):
-            y_pred = self.steps[-1][-1].fit_predict(Xt, y, **fit_params)
+            y_pred = self.steps[-1][-1].fit_predict(Xt, y,
+                                                    **fit_params_last_step)
         return y_pred
 
     @if_delegate_has_method(delegate='_final_estimator')
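
For reference, below is a minimal standalone sketch of the double-underscore routing that the new `_check_fit_params` helper takes over from `_fit`, so that `fit`, `fit_transform` and `fit_predict` can each look up the last step's params themselves instead of having `_fit` return them. The function name `split_fit_params` and the step names used in the example are hypothetical; this sketch is not part of the patch or of scikit-learn's API.

# Standalone sketch (illustration only): mimics splitting 'step__param'
# keyword arguments into one dict of fit params per named step.
def split_fit_params(step_names, **fit_params):
    fit_params_steps = {name: {} for name in step_names}
    for pname, pval in fit_params.items():
        if '__' not in pname:
            raise ValueError(
                "Parameters must use the stepname__parameter format, "
                "e.g. `clf__sample_weight`; got {!r}".format(pname))
        step, param = pname.split('__', 1)
        fit_params_steps[step][param] = pval
    return fit_params_steps


# Example: only the final step receives sample_weight; after this patch,
# each public method fetches that last-step dict itself.
params = split_fit_params(['scaler', 'clf'], clf__sample_weight=[1.0, 2.0])
assert params == {'scaler': {}, 'clf': {'sample_weight': [1.0, 2.0]}}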