ENH Allow `sample_weight` and other `fit_params` in RFE (#20380) · scikit-learn/scikit-learn@eb901df · GitHub
[go: up one dir, main page]

Skip to content

Commit eb901df

Browse files
fbidu, glemaitre, adrinjalali
authored
ENH Allow sample_weight and other fit_params in RFE (#20380)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 39f37bb commit eb901df

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

doc/whats_new/v1.0.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,11 @@ Changelog
349349
when the variance threshold is negative.
350350
:pr:`20207` by :user:`Tomohiro Endo <europeanplaice>`
351351

352+
- |Enhancement| :func:`feature_selection.RFE.fit` accepts additional estimator
353+
parameters that are passed directly to the estimator's `fit` method.
354+
:pr:`20380` by :user:`Iván Pulido <ijpulidos>`, :user:`Felipe Bidu <fbidu>`,
355+
:user:`Gil Rutter <g-rutter>`, and :user:`Adrin Jalali <adrinjalali>`.
356+
352357
- |FIX| Fix a bug in :func:`isotonic.isotonic_regression` where the
353358
`sample_weight` passed by a user were overwritten during the fit.
354359
:pr:`20515` by :user:`Carsten Allefeld <allefeld>`.

sklearn/feature_selection/_rfe.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def classes_(self):
192192
"""
193193
return self.estimator_.classes_
194194

195-
def fit(self, X, y):
195+
def fit(self, X, y, **fit_params):
196196
"""Fit the RFE model and then the underlying estimator on the selected features.
197197
198198
Parameters
@@ -203,14 +203,18 @@ def fit(self, X, y):
203203
y : array-like of shape (n_samples,)
204204
The target values.
205205
206+
**fit_params : dict
207+
Additional parameters passed to the `fit` method of the underlying
208+
estimator.
209+
206210
Returns
207211
-------
208212
self : object
209213
Fitted estimator.
210214
"""
211-
return self._fit(X, y)
215+
return self._fit(X, y, **fit_params)
212216

213-
def _fit(self, X, y, step_score=None):
217+
def _fit(self, X, y, step_score=None, **fit_params):
214218
# Parameter step_score controls the calculation of self.scores_
215219
# step_score is not exposed to users
216220
# and is used when implementing RFECV
@@ -269,7 +273,7 @@ def _fit(self, X, y, step_score=None):
269273
if self.verbose > 0:
270274
print("Fitting estimator with %d features." % np.sum(support_))
271275

272-
estimator.fit(X[:, features], y)
276+
estimator.fit(X[:, features], y, **fit_params)
273277

274278
# Get importance and rank them
275279
importances = _get_feature_importances(
@@ -296,7 +300,7 @@ def _fit(self, X, y, step_score=None):
296300
# Set final attributes
297301
features = np.arange(n_features)[support_]
298302
self.estimator_ = clone(self.estimator)
299-
self.estimator_.fit(X[:, features], y)
303+
self.estimator_.fit(X[:, features], y, **fit_params)
300304

301305
# Compute step score when only n_features_to_select features left
302306
if step_score:
@@ -325,7 +329,7 @@ def predict(self, X):
325329
return self.estimator_.predict(self.transform(X))
326330

327331
@if_delegate_has_method(delegate="estimator")
328-
def score(self, X, y):
332+
def score(self, X, y, **fit_params):
329333
"""Reduce X to the selected features and return the score of the underlying estimator.
330334
331335
Parameters
@@ -336,14 +340,20 @@ def score(self, X, y):
336340
y : array of shape [n_samples]
337341
The target values.
338342
343+
**fit_params : dict
344+
Parameters to pass to the `score` method of the underlying
345+
estimator.
346+
347+
.. versionadded:: 1.0
348+
339349
Returns
340350
-------
341351
score : float
342352
Score of the underlying base estimator computed with the selected
343353
features returned by `rfe.transform(X)` and `y`.
344354
"""
345355
check_is_fitted(self)
346-
return self.estimator_.score(self.transform(X), y)
356+
return self.estimator_.score(self.transform(X), y, **fit_params)
347357

348358
def _get_support_mask(self):
349359
check_is_fitted(self)

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numpy.testing import assert_array_almost_equal, assert_array_equal
99
from scipy import sparse
1010

11+
from sklearn.base import BaseEstimator, ClassifierMixin
1112
from sklearn.feature_selection import RFE, RFECV
1213
from sklearn.datasets import load_iris, make_friedman1
1314
from sklearn.metrics import zero_one_loss
@@ -108,6 +109,31 @@ def test_rfe():
108109
assert_array_almost_equal(X_r, X_r_sparse.toarray())
109110

110111

112+
def test_RFE_fit_score_params():
113+
# Make sure RFE passes the metadata down to fit and score methods of the
114+
# underlying estimator
115+
class TestEstimator(BaseEstimator, ClassifierMixin):
116+
def fit(self, X, y, prop=None):
117+
if prop is None:
118+
raise ValueError("fit: prop cannot be None")
119+
self.svc_ = SVC(kernel="linear").fit(X, y)
120+
self.coef_ = self.svc_.coef_
121+
return self
122+
123+
def score(self, X, y, prop=None):
124+
if prop is None:
125+
raise ValueError("score: prop cannot be None")
126+
return self.svc_.score(X, y)
127+
128+
X, y = load_iris(return_X_y=True)
129+
with pytest.raises(ValueError, match="fit: prop cannot be None"):
130+
RFE(estimator=TestEstimator()).fit(X, y)
131+
with pytest.raises(ValueError, match="score: prop cannot be None"):
132+
RFE(estimator=TestEstimator()).fit(X, y, prop="foo").score(X, y)
133+
134+
RFE(estimator=TestEstimator()).fit(X, y, prop="foo").score(X, y, prop="foo")
135+
136+
111137
@pytest.mark.parametrize("n_features_to_select", [-1, 2.1])
112138
def test_rfe_invalid_n_features_errors(n_features_to_select):
113139
clf = SVC(kernel="linear")

0 commit comments

Comments
 (0)
0