10
10
import numpy as np
11
11
12
12
from ..base import BaseEstimator , MetaEstimatorMixin , _fit_context , clone , is_classifier
13
- from ..metrics import get_scorer_names
13
+ from ..metrics import check_scoring , get_scorer_names
14
14
from ..model_selection import check_cv , cross_val_score
15
+ from ..utils ._metadata_requests import (
16
+ MetadataRouter ,
17
+ MethodMapping ,
18
+ _raise_for_params ,
19
+ _routing_enabled ,
20
+ process_routing ,
21
+ )
15
22
from ..utils ._param_validation import HasMethods , Interval , RealNotInt , StrOptions
16
23
from ..utils ._tags import _safe_tags
17
- from ..utils .metadata_routing import _RoutingNotSupportedMixin
18
24
from ..utils .validation import check_is_fitted
19
25
from ._base import SelectorMixin
20
26
21
27
22
- class SequentialFeatureSelector (
23
- _RoutingNotSupportedMixin , SelectorMixin , MetaEstimatorMixin , BaseEstimator
24
- ):
28
+ class SequentialFeatureSelector (SelectorMixin , MetaEstimatorMixin , BaseEstimator ):
25
29
"""Transformer that performs Sequential Feature Selection.
26
30
27
31
This Sequential Feature Selector adds (forward selection) or
@@ -191,7 +195,7 @@ def __init__(
191
195
# SequentialFeatureSelector.estimator is not validated yet
192
196
prefer_skip_nested_validation = False
193
197
)
194
- def fit (self , X , y = None ):
198
+ def fit (self , X , y = None , ** params ):
195
199
"""Learn the features to select from X.
196
200
197
201
Parameters
@@ -204,11 +208,24 @@ def fit(self, X, y=None):
204
208
Target values. This parameter may be ignored for
205
209
unsupervised learning.
206
210
211
+ **params : dict, default=None
212
+ Parameters to be passed to the underlying `estimator`, `cv`
213
+ and `scorer` objects.
214
+
215
+ .. versionadded:: 1.6
216
+
217
+ Only available if `enable_metadata_routing=True`,
218
+ which can be set by using
219
+ ``sklearn.set_config(enable_metadata_routing=True)``.
220
+ See :ref:`Metadata Routing User Guide <metadata_routing>` for
221
+ more details.
222
+
207
223
Returns
208
224
-------
209
225
self : object
210
226
Returns the instance itself.
211
227
"""
228
+ _raise_for_params (params , self , "fit" )
212
229
tags = self ._get_tags ()
213
230
X = self ._validate_data (
214
231
X ,
@@ -251,9 +268,15 @@ def fit(self, X, y=None):
251
268
252
269
old_score = - np .inf
253
270
is_auto_select = self .tol is not None and self .n_features_to_select == "auto"
271
+
272
+ # We only need to verify the routing here and not use the routed params
273
+ # because internally the actual routing will also take place inside the
274
+ # `cross_val_score` function.
275
+ if _routing_enabled ():
276
+ process_routing (self , "fit" , ** params )
254
277
for _ in range (n_iterations ):
255
278
new_feature_idx , new_score = self ._get_best_new_feature_score (
256
- cloned_estimator , X , y , cv , current_mask
279
+ cloned_estimator , X , y , cv , current_mask , ** params
257
280
)
258
281
if is_auto_select and ((new_score - old_score ) < self .tol ):
259
282
break
@@ -269,7 +292,7 @@ def fit(self, X, y=None):
269
292
270
293
return self
271
294
272
- def _get_best_new_feature_score (self , estimator , X , y , cv , current_mask ):
295
+ def _get_best_new_feature_score (self , estimator , X , y , cv , current_mask , ** params ):
273
296
# Return the best new feature and its score to add to the current_mask,
274
297
# i.e. return the best new feature and its score to add (resp. remove)
275
298
# when doing forward selection (resp. backward selection).
@@ -290,6 +313,7 @@ def _get_best_new_feature_score(self, estimator, X, y, cv, current_mask):
290
313
cv = cv ,
291
314
scoring = self .scoring ,
292
315
n_jobs = self .n_jobs ,
316
+ params = params ,
293
317
).mean ()
294
318
new_feature_idx = max (scores , key = lambda feature_idx : scores [feature_idx ])
295
319
return new_feature_idx , scores [new_feature_idx ]
@@ -302,3 +326,32 @@ def _more_tags(self):
302
326
return {
303
327
"allow_nan" : _safe_tags (self .estimator , key = "allow_nan" ),
304
328
}
329
+
330
+ def get_metadata_routing (self ):
331
+ """Get metadata routing of this object.
332
+
333
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
334
+ mechanism works.
335
+
336
+ .. versionadded:: 1.6
337
+
338
+ Returns
339
+ -------
340
+ routing : MetadataRouter
341
+ A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
342
+ routing information.
343
+ """
344
+ router = MetadataRouter (owner = self .__class__ .__name__ )
345
+ router .add (
346
+ estimator = self .estimator ,
347
+ method_mapping = MethodMapping ().add (caller = "fit" , callee = "fit" ),
348
+ )
349
+ router .add (
350
+ splitter = check_cv (self .cv , classifier = is_classifier (self .estimator )),
351
+ method_mapping = MethodMapping ().add (caller = "fit" , callee = "split" ),
352
+ )
353
+ router .add (
354
+ scorer = check_scoring (self .estimator , scoring = self .scoring ),
355
+ method_mapping = MethodMapping ().add (caller = "fit" , callee = "score" ),
356
+ )
357
+ return router
0 commit comments