@@ -108,12 +108,20 @@ def get_params(self, deep=True):
108
108
return out
109
109
110
110
# Estimator interface
111
-
112
- def _pre_transform (self , X , y = None , ** fit_params ):
113
- fit_params_steps = dict ((step , {}) for step , _ in self .steps )
114
- for pname , pval in six .iteritems (fit_params ):
111
+
112
+ def _extract_params (self , ** params ):
113
+ params_steps = dict ((step , {}) for step , _ in self .steps )
114
+ for pname , pval in six .iteritems (params ):
115
115
step , param = pname .split ('__' , 1 )
116
- fit_params_steps [step ][param ] = pval
116
+ params_steps [step ][param ] = pval
117
+ return params_steps
118
+
119
+ def _pre_transform (self , X , y = None , ** fit_params ):
120
+ # fit_params_steps = dict((step, {}) for step, _ in self.steps)
121
+ # for pname, pval in six.iteritems(fit_params):
122
+ # step, param = pname.split('__', 1)
123
+ # fit_params_steps[step][param] = pval
124
+ fit_params_steps = self ._extract_params (** fit_params )
117
125
Xt = X
118
126
for name , transform in self .steps [:- 1 ]:
119
127
if hasattr (transform , "fit_transform" ):
@@ -141,64 +149,71 @@ def fit_transform(self, X, y=None, **fit_params):
141
149
else :
142
150
return self .steps [- 1 ][- 1 ].fit (Xt , y , ** fit_params ).transform (Xt )
143
151
144
- def predict (self , X ):
152
+ def predict (self , X , ** params ):
145
153
"""Applies transforms to the data, and the predict method of the
146
154
final estimator. Valid only if the final estimator implements
147
155
predict."""
156
+ params = self ._extract_params (** params )
148
157
Xt = X
149
158
for name , transform in self .steps [:- 1 ]:
150
- Xt = transform .transform (Xt )
151
- return self .steps [- 1 ][- 1 ].predict (Xt )
159
+ Xt = transform .transform (Xt , ** params [ name ] )
160
+ return self .steps [- 1 ][- 1 ].predict (Xt , ** params [ self . steps [ - 1 ][ 0 ]] )
152
161
153
- def predict_proba (self , X ):
162
+ def predict_proba (self , X , ** params ):
154
163
"""Applies transforms to the data, and the predict_proba method of the
155
164
final estimator. Valid only if the final estimator implements
156
165
predict_proba."""
166
+ params = self ._extract_params (** params )
157
167
Xt = X
158
168
for name , transform in self .steps [:- 1 ]:
159
- Xt = transform .transform (Xt )
160
- return self .steps [- 1 ][- 1 ].predict_proba (Xt )
169
+ Xt = transform .transform (Xt , ** params [ name ] )
170
+ return self .steps [- 1 ][- 1 ].predict_proba (Xt , ** params [ self . steps [ - 1 ][ 0 ]] )
161
171
162
- def decision_function (self , X ):
172
+ def decision_function (self , X , ** params ):
163
173
"""Applies transforms to the data, and the decision_function method of
164
174
the final estimator. Valid only if the final estimator implements
165
175
decision_function."""
176
+ params = self ._extract_params (** params )
166
177
Xt = X
167
178
for name , transform in self .steps [:- 1 ]:
168
- Xt = transform .transform (Xt )
169
- return self .steps [- 1 ][- 1 ].decision_function (Xt )
179
+ Xt = transform .transform (Xt , ** params [ name ] )
180
+ return self .steps [- 1 ][- 1 ].decision_function (Xt , ** params [ self . steps [ - 1 ][ 0 ]] )
170
181
171
- def predict_log_proba (self , X ):
182
+ def predict_log_proba (self , X , ** params ):
183
+ params = self ._extract_params (** params )
172
184
Xt = X
173
185
for name , transform in self .steps [:- 1 ]:
174
- Xt = transform .transform (Xt )
175
- return self .steps [- 1 ][- 1 ].predict_log_proba (Xt )
186
+ Xt = transform .transform (Xt , ** params [ name ] )
187
+ return self .steps [- 1 ][- 1 ].predict_log_proba (Xt , ** params [ self . steps [ - 1 ][ 0 ]] )
176
188
177
- def transform (self , X ):
189
+ def transform (self , X , ** params ):
178
190
"""Applies transforms to the data, and the transform method of the
179
191
final estimator. Valid only if the final estimator implements
180
192
transform."""
193
+ params = self ._extract_params (** params )
181
194
Xt = X
182
195
for name , transform in self .steps :
183
- Xt = transform .transform (Xt )
196
+ Xt = transform .transform (Xt , ** params [ name ] )
184
197
return Xt
185
198
186
- def inverse_transform (self , X ):
199
+ def inverse_transform (self , X , ** params ):
200
+ params = self ._extract_params (** params )
187
201
if X .ndim == 1 :
188
202
X = X [None , :]
189
203
Xt = X
190
204
for name , step in self .steps [::- 1 ]:
191
- Xt = step .inverse_transform (Xt )
205
+ Xt = step .inverse_transform (Xt , ** params [ name ] )
192
206
return Xt
193
207
194
- def score (self , X , y = None ):
208
+ def score (self , X , y = None , ** params ):
195
209
"""Applies transforms to the data, and the score method of the
196
210
final estimator. Valid only if the final estimator implements
197
211
score."""
212
+ params = self ._extract_params (** params )
198
213
Xt = X
199
214
for name , transform in self .steps [:- 1 ]:
200
- Xt = transform .transform (Xt )
201
- return self .steps [- 1 ][- 1 ].score (Xt , y )
215
+ Xt = transform .transform (Xt , ** params [ name ] )
216
+ return self .steps [- 1 ][- 1 ].score (Xt , y , ** params [ self . steps [ - 1 ][ 0 ]] )
202
217
203
218
@property
204
219
def _pairwise (self ):
0 commit comments