[MRG+1] add groups support to RFECV (#9656) · scikit-learn/scikit-learn@48b82a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 48b82a6

Browse files
adamgreenhall authored and TomDLT committed
[MRG+1] add groups support to RFECV (#9656)
1 parent 8ebb9a9 commit 48b82a6

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

sklearn/feature_selection/rfe.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def __init__(self, estimator, step=1, cv=None, scoring=None, verbose=0,
390390
self.verbose = verbose
391391
self.n_jobs = n_jobs
392392

393-
def fit(self, X, y):
393+
def fit(self, X, y, groups=None):
394394
"""Fit the RFE model and automatically tune the number of selected
395395
features.
396396
@@ -403,6 +403,10 @@ def fit(self, X, y):
403403
y : array-like, shape = [n_samples]
404404
Target values (integers for classification, real numbers for
405405
regression).
406+
407+
groups : array-like, shape = [n_samples], optional
408+
Group labels for the samples used while splitting the dataset into
409+
train/test set.
406410
"""
407411
X, y = check_X_y(X, y, "csr")
408412

@@ -442,7 +446,7 @@ def fit(self, X, y):
442446

443447
scores = parallel(
444448
func(rfe, self.estimator, X, y, train, test, scorer)
445-
for train, test in cv.split(X, y))
449+
for train, test in cv.split(X, y, groups))
446450

447451
scores = np.sum(scores, axis=0)
448452
n_features_to_select = max(
@@ -465,5 +469,5 @@ def fit(self, X, y):
465469

466470
# Fixing a normalization error, n is equal to get_n_splits(X, y) - 1
467471
# here, the scores are normalized by get_n_splits(X, y)
468-
self.grid_scores_ = scores[::-1] / cv.get_n_splits(X, y)
472+
self.grid_scores_ = scores[::-1] / cv.get_n_splits(X, y, groups)
469473
return self

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sklearn.svm import SVC, SVR
1212
from sklearn.ensemble import RandomForestClassifier
1313
from sklearn.model_selection import cross_val_score
14+
from sklearn.model_selection import GroupKFold
1415

1516
from sklearn.utils import check_random_state
1617
from sklearn.utils.testing import ignore_warnings
@@ -328,3 +329,21 @@ def test_rfe_cv_n_jobs():
328329
rfecv.fit(X, y)
329330
assert_array_almost_equal(rfecv.ranking_, rfecv_ranking)
330331
assert_array_almost_equal(rfecv.grid_scores_, rfecv_grid_scores)
332+
333+
334+
def test_rfe_cv_groups():
    """Fit RFECV with a group-aware splitter and explicit ``groups``.

    The iris targets are binarised, every sample is assigned to one of
    ``number_groups`` contiguous groups, and ``RFECV.fit`` is called with
    ``cv=GroupKFold(n_splits=2)`` and those group labels.  The test passes
    when the fit completes and at least one feature is selected.
    """
    generator = check_random_state(0)
    iris = load_iris()
    number_groups = 4
    # endpoint=False keeps every value strictly below number_groups, so
    # flooring yields only the labels 0 .. number_groups - 1.  With the
    # endpoint included, the last sample alone would land in an extra,
    # single-element group labelled number_groups.
    groups = np.floor(
        np.linspace(0, number_groups, len(iris.target), endpoint=False))
    X = iris.data
    y = (iris.target > 0).astype(int)

    est_groups = RFECV(
        estimator=RandomForestClassifier(random_state=generator),
        step=1,
        scoring='accuracy',
        cv=GroupKFold(n_splits=2)
    )
    est_groups.fit(X, y, groups=groups)
    assert est_groups.n_features_ > 0

0 commit comments

Comments
 (0)
0