8000 Made changes to pipeline to allow more general parameter passing · jcrudy/scikit-learn@87db631 · GitHub
[go: up one dir, main page]

Skip to content

Commit 87db631

Browse files
committed
Made changes to pipeline to allow more general parameter passing
1 parent ef872fb commit 87db631

File tree

2 files changed

+41
-24
lines changed

2 files changed

+41
-24
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,5 @@ benchmarks/bench_covertype_data/
4444

4545
*.prefs
4646
.pydevproject
47+
48+
.project

sklearn/pipeline.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,20 @@ def get_params(self, deep=True):
108108
return out
109109

110110
# Estimator interface
111-
112-
def _pre_transform(self, X, y=None, **fit_params):
113-
fit_params_steps = dict((step, {}) for step, _ in self.steps)
114-
for pname, pval in six.iteritems(fit_params):
111+
112+
def _extract_params(self, **params):
113+
params_steps = dict((step, {}) for step, _ in self.steps)
114+
for pname, pval in six.iteritems(params):
115115
step, param = pname.split('__', 1)
116-
fit_params_steps[step][param] = pval
116+
params_steps[step][param] = pval
117+
return params_steps
118+
119+
def _pre_transform(self, X, y=None, **fit_params):
120+
# fit_params_steps = dict((step, {}) for step, _ in self.steps)
121+
# for pname, pval in six.iteritems(fit_params):
122+
# step, param = pname.split('__', 1)
123+
# fit_params_steps[step][param] = pval
124+
fit_params_steps = self._extract_params(**fit_params)
117125
Xt = X
118126
for name, transform in self.steps[:-1]:
119127
if hasattr(transform, "fit_transform"):
@@ -141,64 +149,71 @@ def fit_transform(self, X, y=None, **fit_params):
141149
else:
142150
return self.steps[-1][-1].fit(Xt, y, **fit_params).transform(Xt)
143151

144-
def predict(self, X):
152+
def predict(self, X, **params):
145153
"""Applies transforms to the data, and the predict method of the
146154
final estimator. Valid only if the final estimator implements
147155
predict."""
156+
params = self._extract_params(**params)
148157
Xt = X
149158
for name, transform in self.steps[:-1]:
150-
Xt = transform.transform(Xt)
151-
return self.steps[-1][-1].predict(Xt)
159+
Xt = transform.transform(Xt, **params[name])
160+
return self.steps[-1][-1].predict(Xt, **params[self.steps[-1][0]])
152161

153-
def predict_proba(self, X):
162+
def predict_proba(self, X, **params):
154163
"""Applies transforms to the data, and the predict_proba method of the
155164
final estimator. Valid only if the final estimator implements
156165
predict_proba."""
166+
params = self._extract_params(**params)
157167
Xt = X
158168
for name, transform in self.steps[:-1]:
159-
Xt = transform.transform(Xt)
160-
return self.steps[-1][-1].predict_proba(Xt)
169+
Xt = transform.transform(Xt, **params[name])
170+
return self.steps[-1][-1].predict_proba(Xt, **params[self.steps[-1][0]])
161171

162-
def decision_function(self, X):
172+
def decision_function(self, X, **params):
163173
"""Applies transforms to the data, and the decision_function method of
164174
the final estimator. Valid only if the final estimator implements
165175
decision_function."""
176+
params = self._extract_params(**params)
166177
Xt = X
167178
for name, transform in self.steps[:-1]:
168-
Xt = transform.transform(Xt)
169-
return self.steps[-1][-1].decision_function(Xt)
179+
Xt = transform.transform(Xt, **params[name])
180+
return self.steps[-1][-1].decision_function(Xt, **params[self.steps[-1][0]])
170181

171-
def predict_log_proba(self, X):
182+
def predict_log_proba(self, X, **params):
183+
params = self._extract_params(**params)
172184
Xt = X
173185
for name, transform in self.steps[:-1]:
174-
Xt = transform.transform(Xt)
175-
return self.steps[-1][-1].predict_log_proba(Xt)
186+
Xt = transform.transform(Xt, **params[name])
187+
return self.steps[-1][-1].predict_log_proba(Xt, **params[self.steps[-1][0]])
176188

177-
def transform(self, X):
189+
def transform(self, X, **params):
178190
"""Applies transforms to the data, and the transform method of the
179191
final estimator. Valid only if the final estimator implements
180192
transform."""
193+
params = self._extract_params(**params)
181194
Xt = X
182195
for name, transform in self.steps:
183-
Xt = transform.transform(Xt)
196+
Xt = transform.transform(Xt, **params[name])
184197
return Xt
185198

186-
def inverse_transform(self, X):
199+
def inverse_transform(self, X, **params):
200+
params = self._extract_params(**params)
187201
if X.ndim == 1:
188202
X = X[None, :]
189203
Xt = X
190204
for name, step in self.steps[::-1]:
191-
Xt = step.inverse_transform(Xt)
205+
Xt = step.inverse_transform(Xt, **params[name])
192206
return Xt
193207

194-
def score(self, X, y=None):
208+
def score(self, X, y=None, **params):
195209
"""Applies transforms to the data, and the score method of the
196210
final estimator. Valid only if the final estimator implements
197211
score."""
212+
params = self._extract_params(**params)
198213
Xt = X
199214
for name, transform in self.steps[:-1]:
200-
Xt = transform.transform(Xt)
201-
return self.steps[-1][-1].score(Xt, y)
215+
Xt = transform.transform(Xt, **params[name])
216+
return self.steps[-1][-1].score(Xt, y, **params[self.steps[-1][0]])
202217

203218
@property
204219
def _pairwise(self):

0 commit comments

Comments
 (0)
0