8000 using a support_ mask · scikit-learn/scikit-learn@86ace7c · GitHub
[go: up one dir, main page]

Skip to content

Commit 86ace7c

Browse files
committed
using a support_ mask
1 parent ee89b08 commit 86ace7c

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

sklearn/feature_selection/sequential_feature_selector.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ class SequentialFeatureSelector(BaseEstimator, MetaEstimatorMixin,
8282
8383
Attributes
8484
----------
85-
feature_subset_idx_ : array-like, shape = [n_predictions]
86-
Feature Indices of the selected feature subsets.
85+
support_ : array of shape [n_features]
86+
The mask of selected features.
8787
8888
score_ : float
8989
Cross validation average score of the selected subset.
@@ -221,7 +221,7 @@ def fit(self, X, y):
221221
self.subsets_[k] = {
222222
'feature_subset_idx': k_idx,
223223
'cv_scores': k_score,
224-
'avg_score': np.nanmean(k_score)
224+
'avg_score': np.mean(k_score)
225225
}
226226

227227
best_subset = None
@@ -266,7 +266,8 @@ def fit(self, X, y):
266266
k_score = max_score
267267
k_idx = self.subsets_[best_subset]['feature_subset_idx']
268268

269-
self.feature_subset_idx_ = k_idx
269+
self.support_ = k_idx
270+
self.support_ = self._get_support_mask()
270271
self.score_ = k_score
271272
return self
272273

@@ -324,8 +325,8 @@ def _exclusion(self, feature_set, X, y, estimator, fixed_feature=None):
324325
return res
325326

326327
def _get_support_mask(self):
327-
check_is_fitted(self, 'feature_subset_idx_')
328+
check_is_fitted(self, 'support_')
328329
mask = np.zeros((self._n_features,), dtype=np.bool)
329330 # list to avoid IndexError in old NumPy versions
330-
mask[list(self.feature_subset_idx_)] = True
331+
mask[list(self.support_)] = True
331332
return mask

sklearn/feature_selection/tests/test_sequential_feature_selector.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import numpy as np
55
from numpy.testing import assert_almost_equal
6+
from numpy.testing import assert_array_equal
67
from sklearn.utils.testing import assert_raise_message
78
from sklearn.neighbors import KNeighborsClassifier
89
from sklearn.linear_model import LinearRegression
@@ -37,7 +38,8 @@ def test_run_default():
3738
knn = KNeighborsClassifier()
3839
sfs = SFS(estimator=knn)
3940
sfs.fit(X, y)
40-
assert sfs.feature_subset_idx_ == (3, )
41+
assert_array_equal(sfs.support_,
42+
np.array([False, False, False, True]))
4143

4244

4345
def test_kfeatures_type_1():
@@ -163,7 +165,8 @@ def test_knn_option_sbs():
163165
forward=False,
164166
cv=4)
165167
sfs3 = sfs3.fit(X, y)
166-
assert sfs3.feature_subset_idx_ == (1, 2, 3)
168+
assert_array_equal(sfs3.support_,
169+
np.array([False, True, True, True]))
167170

168171

169172
def test_knn_option_sfs_tuplerange():
@@ -177,7 +180,8 @@ def test_knn_option_sfs_tuplerange():
177180
cv=4)
178181
sfs4 = sfs4.fit(X, y)
179182
assert round(sfs4.score_, 3) == 0.967
180-
assert sfs4.feature_subset_idx_ == (0, 2, 3)
183+
assert_array_equal(sfs4.support_,
184+
np.array([True, False, True, True]))
181185

182186

183187
def test_knn_scoring_metric():
@@ -217,7 +221,7 @@ def test_regression():
217221
forward=True,
218222
cv=10)
219223
sfs_r = sfs_r.fit(X, y)
220-
assert len(sfs_r.feature_subset_idx_) == 13
224+
assert sum(sfs_r.support_) == 13
221225
assert round(sfs_r.score_, 4) == 0.2001
222226

223227

@@ -233,7 +237,7 @@ def test_regression_in_tuplerange_forward():
233237
forward=True,
234238
cv=10)
235239
sfs_r = sfs_r.fit(X, y)
236-
assert len(sfs_r.feature_subset_idx_) == 9
240+
assert sum(sfs_r.support_) == 9
237241
assert round(sfs_r.score_, 4) == 0.2991, sfs_r.score_
238242

239243

@@ -252,7 +256,12 @@ def test_regression_in_tuplerange_backward():
252256
cv=10)
253257

254258
sfs_r = sfs_r.fit(X, y)
255-
assert len(sfs_r.feature_subset_idx_) == 5
259+
print(sfs_r.support_)
260+
print(type(sfs_r.support_))
261+
assert_array_equal(sfs_r.support_,
262+
np.array([False, False, False, False,
263+
True, False, False, True,
264+
True, False, True, False, True]))
256265

257266

258267
def test_transform_not_fitted():

0 commit comments

Comments
 (0)
0