From 528533dab7bbf0a6ba6b646bae486b8639321da9 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Mon, 21 Mar 2016 01:16:04 -0400 Subject: [PATCH] MAINT: Simplify n_features_to_select in RFECV --- sklearn/feature_selection/rfe.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/sklearn/feature_selection/rfe.py b/sklearn/feature_selection/rfe.py index 5df14b9037ca0..5eef0ac659267 100644 --- a/sklearn/feature_selection/rfe.py +++ b/sklearn/feature_selection/rfe.py @@ -421,14 +421,11 @@ def fit(self, X, y): func(rfe, self.estimator, X, y, train, test, scorer) for train, test in cv.split(X, y)) - scores = np.sum(scores, axis=0)[::-1] - # The index in 'scores' when 'n_features' features are selected - n_feature_index = np.ceil((n_features - n_features_to_select) / - float(self.step)) - n_features_to_select = max(n_features_to_select, - n_features - ((n_feature_index - - np.argmax(scores)) * - self.step)) + scores = np.sum(scores, axis=0) + n_features_to_select = max( + n_features - (np.argmax(scores) * self.step), + n_features_to_select) + # Re-execute an elimination with best_k over the whole set rfe = RFE(estimator=self.estimator, n_features_to_select=n_features_to_select, step=self.step) @@ -444,5 +441,5 @@ def fit(self, X, y): # Fixing a normalization error, n is equal to get_n_splits(X, y) - 1 # here, the scores are normalized by get_n_splits(X, y) - self.grid_scores_ = scores / cv.get_n_splits(X, y) + self.grid_scores_ = scores[::-1] / cv.get_n_splits(X, y) return self