FEA Add metadata routing for SequentialFeatureSelector (#29260) · scikit-learn/scikit-learn@234260d · GitHub
[go: up one dir, main page]

Skip to content

Commit 234260d

Browse files
OmarManzoor, adam2392, glemaitre
authored
FEA Add metadata routing for SequentialFeatureSelector (#29260)
Co-authored-by: Adam Li <adam2392@gmail.com> Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 06eafd8 commit 234260d

File tree

5 files changed

+88
-10
lines changed

5 files changed

+88
-10
lines changed

doc/metadata_routing.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ Meta-estimators and functions supporting metadata routing:
285285
- :class:`sklearn.ensemble.BaggingClassifier`
286286
- :class:`sklearn.ensemble.BaggingRegressor`
287287
- :class:`sklearn.feature_selection.SelectFromModel`
288+
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
288289
- :class:`sklearn.impute.IterativeImputer`
289290
- :class:`sklearn.linear_model.ElasticNetCV`
290291
- :class:`sklearn.linear_model.LarsCV`
@@ -324,4 +325,3 @@ Meta-estimators and tools not supporting metadata routing yet:
324325
- :class:`sklearn.ensemble.AdaBoostRegressor`
325326
- :class:`sklearn.feature_selection.RFE`
326327
- :class:`sklearn.feature_selection.RFECV`
327-
- :class:`sklearn.feature_selection.SequentialFeatureSelector`

doc/whats_new/v1.6.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ more details.
8383
params to the underlying regressor.
8484
:pr:`29136` by :user:`Omar Salman <OmarManzoor>`.
8585

86+
- |Feature| :class:`feature_selection.SequentialFeatureSelector` now supports
87+
metadata routing in its `fit` method and passes the corresponding params to
88+
the :func:`model_selection.cross_val_score` function.
89+
:pr:`29260` by :user:`Omar Salman <OmarManzoor>`.
90+
8691
- |Feature| :func:`model_selection.validation_curve` now supports metadata routing for
8792
the `fit` method of its estimator and for its underlying CV splitter and scorer.
8893
:pr:`29329` by :user:`Stefanie Senger <StefanieSenger>`.

sklearn/feature_selection/_sequential.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,22 @@
1010
import numpy as np
1111

1212
from ..base import BaseEstimator, MetaEstimatorMixin, _fit_context, clone, is_classifier
13-
from ..metrics import get_scorer_names
13+
from ..metrics import check_scoring, get_scorer_names
1414
from ..model_selection import check_cv, cross_val_score
15+
from ..utils._metadata_requests import (
16+
MetadataRouter,
17+
MethodMapping,
18+
_raise_for_params,
19+
_routing_enabled,
20+
process_routing,
21+
)
1522
from ..utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions
1623
from ..utils._tags import _safe_tags
17-
from ..utils.metadata_routing import _RoutingNotSupportedMixin
1824
from ..utils.validation import check_is_fitted
1925
from ._base import SelectorMixin
2026

2127

22-
class SequentialFeatureSelector(
23-
_RoutingNotSupportedMixin, SelectorMixin, MetaEstimatorMixin, BaseEstimator
24-
):
28+
class SequentialFeatureSelector(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
2529
"""Transformer that performs Sequential Feature Selection.
2630
2731
This Sequential Feature Selector adds (forward selection) or
@@ -191,7 +195,7 @@ def __init__(
191195
# SequentialFeatureSelector.estimator is not validated yet
192196
prefer_skip_nested_validation=False
193197
)
194-
def fit(self, X, y=None):
198+
def fit(self, X, y=None, **params):
195199
"""Learn the features to select from X.
196200
197201
Parameters
@@ -204,11 +208,24 @@ def fit(self, X, y=None):
204208
Target values. This parameter may be ignored for
205209
unsupervised learning.
206210
211+
**params : dict, default=None
212+
Parameters to be passed to the underlying `estimator`, `cv`
213+
and `scorer` objects.
214+
215+
.. versionadded:: 1.6
216+
217+
Only available if `enable_metadata_routing=True`,
218+
which can be set by using
219+
``sklearn.set_config(enable_metadata_routing=True)``.
220+
See :ref:`Metadata Routing User Guide <metadata_routing>` for
221+
more details.
222+
207223
Returns
208224
-------
209225
self : object
210226
Returns the instance itself.
211227
"""
228+
_raise_for_params(params, self, "fit")
212229
tags = self._get_tags()
213230
X = self._validate_data(
214231
X,
@@ -251,9 +268,15 @@ def fit(self, X, y=None):
251268

252269
old_score = -np.inf
253270
is_auto_select = self.tol is not None and self.n_features_to_select == "auto"
271+
272+
# We only need to verify the routing here and not use the routed params
273+
# because internally the actual routing will also take place inside the
274+
# `cross_val_score` function.
275+
if _routing_enabled():
276+
process_routing(self, "fit", **params)
254277
for _ in range(n_iterations):
255278
new_feature_idx, new_score = self._get_best_new_feature_score(
256-
cloned_estimator, X, y, cv, current_mask
279+
cloned_estimator, X, y, cv, current_mask, **params
257280
)
258281
if is_auto_select and ((new_score - old_score) < self.tol):
259282
break
@@ -269,7 +292,7 @@ def fit(self, X, y=None):
269292

270293
return self
271294

272-
def _get_best_new_feature_score(self, estimator, X, y, cv, current_mask):
295+
def _get_best_new_feature_score(self, estimator, X, y, cv, current_mask, **params):
273296
# Return the best new feature and its score to add to the current_mask,
274297
# i.e. return the best new feature and its score to add (resp. remove)
275298
# when doing forward selection (resp. backward selection).
@@ -290,6 +313,7 @@ def _get_best_new_feature_score(self, estimator, X, y, cv, current_mask):
290313
cv=cv,
291314
scoring=self.scoring,
292315
n_jobs=self.n_jobs,
316+
params=params,
293317
).mean()
294318
new_feature_idx = max(scores, key=lambda feature_idx: scores[feature_idx])
295319
return new_feature_idx, scores[new_feature_idx]
@@ -302,3 +326,32 @@ def _more_tags(self):
302326
return {
303327
"allow_nan": _safe_tags(self.estimator, key="allow_nan"),
304328
}
329+
330+
def get_metadata_routing(self):
    """Get metadata routing of this object.

    The returned router declares three routes, all triggered by this
    object's ``fit``: to the ``fit`` method of the wrapped estimator, to the
    ``split`` method of the resolved cross-validation splitter, and to the
    ``score`` method of the resolved scorer.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    .. versionadded:: 1.6

    Returns
    -------
    routing : MetadataRouter
        A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    # `cv` and `scoring` are resolved into concrete splitter / scorer objects
    # inline, in the same order as the routes are registered.
    return (
        MetadataRouter(owner=self.__class__.__name__)
        .add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        .add(
            splitter=check_cv(self.cv, classifier=is_classifier(self.estimator)),
            method_mapping=MethodMapping().add(caller="fit", callee="split"),
        )
        .add(
            scorer=check_scoring(self.estimator, scoring=self.scoring),
            method_mapping=MethodMapping().add(caller="fit", callee="score"),
        )
    )

sklearn/feature_selection/tests/test_sequential.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,12 @@ def test_cv_generator_support():
321321

322322
sfs = SequentialFeatureSelector(knc, n_features_to_select=5, cv=splits)
323323
sfs.fit(X, y)
324+
325+
326+
def test_fit_rejects_params_with_no_routing_enabled():
    """Check that `fit` raises a ValueError when given extra metadata.

    The test does not enable metadata routing itself; the expected message
    fragment is "is only supported if".
    """
    X, y = make_classification(random_state=42)
    selector = SequentialFeatureSelector(estimator=LinearRegression())
    weights = np.ones_like(y)

    with pytest.raises(ValueError, match="is only supported if"):
        selector.fit(X, y, sample_weight=weights)

sklearn/tests/test_metaestimators_metadata_routing.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,18 @@ def enable_slep006():
407407
],
408408
"method_mapping": {"fit": ["fit", "score"]},
409409
},
410+
{
411+
"metaestimator": SequentialFeatureSelector,
412+
"estimator_name": "estimator",
413+
"estimator": "classifier",
414+
"X": X,
415+
"y": y,
416+
"estimator_routing_methods": ["fit"],
417+
"scorer_name": "scoring",
418+
"scorer_routing_methods": ["fit"],
419+
"cv_name": "cv",
420+
"cv_routing_methods": ["fit"],
421+
},
410422
]
411423
"""List containing all metaestimators to be tested and their settings
412424
@@ -450,7 +462,6 @@ def enable_slep006():
450462
AdaBoostRegressor(),
451463
RFE(ConsumingClassifier()),
452464
RFECV(ConsumingClassifier()),
453-
SequentialFeatureSelector(ConsumingClassifier()),
454465
]
455466

456467

0 commit comments

Comments
 (0)
0