Select k-best features in SelectFromModel · scikit-learn/scikit-learn@41b9679 · GitHub
[go: up one dir, main page]

Skip to content

Commit 41b9679

Browse files
committed
Select k-best features in SelectFromModel
1 parent eed5fc5 commit 41b9679

File tree

2 files changed

+147
-4
lines changed

2 files changed

+147
-4
lines changed

sklearn/feature_selection/from_model.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..base import TransformerMixin, BaseEstimator, clone
88
from ..externals import six
99

10-
from ..utils import safe_mask, check_array, deprecated
10+
from ..utils import safe_mask, check_array, deprecated, check_X_y
1111
from ..utils.validation import check_is_fitted
1212
from ..exceptions import NotFittedError
1313

@@ -173,6 +173,10 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
173173
Otherwise train the model using ``fit`` and then ``transform`` to do
174174
feature selection.
175175
176+
max_features : int, between 0 and number of features, default None.
177+
If provided, the first ``max_features`` most important features are
178+
kept and the ``threshold`` parameter is ignored.
179+
176180
Attributes
177181
----------
178182
`estimator_`: an estimator
@@ -183,10 +187,28 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
183187
`threshold_`: float
184188
The threshold value used for feature selection.
185189
"""
186-
def __init__(self, estimator, threshold=None, prefit=False):
190+
def __init__(self, estimator, threshold=None, prefit=False,
191+
max_features=None):
187192
self.estimator = estimator
188193
self.threshold = threshold
189194
self.prefit = prefit
195+
self.max_features = max_features
196+
197+
def _check_max_features(self, X, max_features):
198+
if self.max_features is not None:
199+
if isinstance(self.max_features, int):
200+
if 0 <= self.max_features <= X.shape[1]:
201+
return
202+
elif self.max_features == 'all':
203+
return
204+
raise ValueError(
205+
"max_features should be >=0, <= n_features; got %r."
206+
" Use max_features='all' to return all features."
207+
% self.max_features)
208+
209+
def _check_params(self, X, y):
210+
X, y = check_X_y(X, y)
211+
self._check_max_features(X, self.max_features)
190212

191213
def _get_support_mask(self):
192214
# SelectFromModel can directly call on transform.
@@ -201,7 +223,15 @@ def _get_support_mask(self):
201223
scores = _get_feature_importances(estimator)
202224
self.threshold_ = _calculate_threshold(estimator, scores,
203225
self.threshold)
204-
return scores >= self.threshold_
226+
mask = np.zeros_like(scores, dtype=bool)
227+
if self.max_features == 'all':
228+
self.max_features = scores.size
229+
candidate_indices = np.argsort(-scores,
230+
kind='mergesort')[:self.max_features]
231+
mask[candidate_indices] = True
232+
if self.threshold is not None:
233+
mask[scores < self.threshold_] = False
234+
return mask
205235

206236
def fit(self, X, y=None, **fit_params):
207237
"""Fit the SelectFromModel meta-transformer.
@@ -222,6 +252,7 @@ def fit(self, X, y=None, **fit_params):
222252
self : object
223253
Returns self.
224254
"""
255+
self._check_params(X, y)
225256
if self.prefit:
226257
raise NotFittedError(
227258
"Since 'prefit=True', call transform directly")

sklearn/feature_selection/tests/test_from_model.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
from sklearn.utils.testing import assert_almost_equal
1111
from sklearn.utils.testing import assert_warns
1212
from sklearn.utils.testing import skip_if_32bit
13+
from sklearn.utils.testing import assert_equal
1314

1415
from sklearn import datasets
1516
from sklearn.linear_model import LogisticRegression, SGDClassifier, Lasso
1617
from sklearn.svm import LinearSVC
1718
from sklearn.feature_selection import SelectFromModel
1819
from sklearn.ensemble import RandomForestClassifier
1920
from sklearn.linear_model import PassiveAggressiveClassifier
21+
from sklearn.base import BaseEstimator
2022

2123
iris = datasets.load_iris()
2224
data, y = iris.data, iris.target
@@ -63,6 +65,112 @@ def test_input_estimator_unchanged():
6365
assert_true(transformer.estimator is est)
6466

6567

68+
def check_invalid_max_features(est, X, y):
69+
max_features = X.shape[1]
70+
for invalid_max_n_feature in [-1, max_features + 1, 'gobbledigook']:
71+
transformer = SelectFromModel(estimator=est,
72+
max_features=invalid_max_n_feature)
73+
assert_raises(ValueError, transformer.fit, X, y)
74+
75+
76+
def check_valid_max_features(est, X, y):
77+
max_features = X.shape[1]
78+
for valid_max_n_feature in [0, max_features, 'all', 5]:
79+
transformer = SelectFromModel(estimator=est,
80+
max_features=valid_max_n_feature)
81+
X_new = transformer.fit_transform(X, y)
82+
if valid_max_n_feature == 'all':
83+
valid_max_n_feature = max_features
84+
assert_equal(X_new.shape[1], valid_max_n_feature)
85+
86+
87+
class FixedImportanceEstimator(BaseEstimator):
89+
def __init__(self, importances):
89+
self.importances = importances
90+
91+
def fit(self, X, y=None):
92+
self.feature_importances_ = np.array(self.importances)
93+
94+
95+
def check_max_features(est, X, y):
96+
X = X.copy()
97+
max_features = X.shape[1]
98+
99+
check_valid_max_features(est, X, y)
100+
check_invalid_max_features(est, X, y)
101+
102+
transformer1 = SelectFromModel(estimator=est, max_features='all')
103+
transformer2 = SelectFromModel(estimator=est,
104+
max_features=max_features)
105+
X_new1 = transformer1.fit_transform(X, y)
106+
X_new2 = transformer2.fit_transform(X, y)
107+
assert_array_equal(X_new1, X_new2)
108+
109+
# Test max_features against actual model.
110+
111+
transformer1 = SelectFromModel(estimator=Lasso(alpha=0.025))
112+
X_new1 = transformer1.fit_transform(X, y)
113+
for n_features in range(1, X_new1.shape[1] + 1):
114+
transformer2 = SelectFromModel(estimator=Lasso(alpha=0.025),
115+
max_features=n_features)
116+
X_new2 = transformer2.fit_transform(X, y)
117+
assert_array_equal(X_new1[:, :n_features], X_new2)
118+
assert_array_equal(transformer1.estimator_.coef_,
119+
transformer2.estimator_.coef_)
120+
121+
# Test if max_features can break tie among feature importance
122+
123+
feature_importances = np.array([4, 4, 4, 4, 3, 3, 3, 2, 2, 1])
124+
for n_features in range(1, max_features + 1):
125+
transformer = SelectFromModel(
126+
FixedImportanceEstimator(feature_importances),
127+
max_features=n_features)
128+
X_new = transformer.fit_transform(X, y)
129+
selected_feature_indices = np.where(transformer._get_support_mask())[0]
130+
assert_array_equal(selected_feature_indices, np.arange(n_features))
131+
assert_equal(X_new.shape[1], n_features)
132+
133+
134+
def check_threshold_and_max_features(est, X, y):
135+
transformer1 = SelectFromModel(estimator=est, max_features=3)
136+
X_new1 = transformer1.fit_transform(X, y)
137+
138+
transformer2 = SelectFromModel(estimator=est, threshold=0.04)
139+
X_new2 = transformer2.fit_transform(X, y)
140+
141+
transformer3 = SelectFromModel(estimator=est, max_features=3,
142+
threshold=0.04)
143+
X_new3 = transformer3.fit_transform(X, y)
144+
assert_equal(X_new3.shape[1], min(X_new1.shape[1], X_new2.shape[1]))
145+
selected_indices = \
146+
transformer3.transform(np.arange(X.shape[1]))[np.newaxis, :]
147+
assert_array_equal(X_new3, X[:, selected_indices[0][0]])
148+
149+
"""
150+
If threshold and max_features are not provided, all features are
151+
returned, use threshold=None if it is not required.
152+
"""
153+
transformer = SelectFromModel(estimator=Lasso(alpha=0.1))
154+
X_new = transformer.fit_transform(X, y)
155+
assert_array_equal(X, X_new)
156+
157+
transformer = SelectFromModel(estimator=Lasso(alpha=0.1), max_features=3)
158+
X_new = transformer.fit_transform(X, y)
159+
assert_equal(X_new.shape[1], 3)
160+
161+
# Threshold will be applied if it is not None
162+
transformer = SelectFromModel(estimator=Lasso(alpha=0.1), threshold=1e-5)
163+
X_new = transformer.fit_transform(X, y)
164+
mask = np.abs(transformer.estimator_.coef_) > 1e-5
165+
assert_array_equal(X_new, X[:, mask])
166+
167+
transformer = SelectFromModel(estimator=Lasso(alpha=0.1), threshold=1e-5,
168+
max_features=4)
169+
X_new = transformer.fit_transform(X, y)
170+
mask = np.abs(transformer.estimator_.coef_) > 1e-5
171+
assert_array_equal(X_new, X[:, mask])
172+
173+
66174
@skip_if_32bit
67175
def test_feature_importances():
68176
X, y = datasets.make_classification(
@@ -95,12 +203,16 @@ def test_feature_importances():
95203
assert_almost_equal(importances, importances_bis)
96204

97205
# For the Lasso and related models, the threshold defaults to 1e-5
98-
transformer = SelectFromModel(estimator=Lasso(alpha=0.1))
206+
transformer = SelectFromModel(estimator=Lasso(alpha=0.1), threshold=1e-5)
99207
transformer.fit(X, y)
100208
X_new = transformer.transform(X)
101209
mask = np.abs(transformer.estimator_.coef_) > 1e-5
102210
assert_array_equal(X_new, X[:, mask])
103211

212+
# Test max_features parameter using various values
213+
check_max_features(est, X, y)
214+
check_threshold_and_max_features(est, X, y)
215+
104216

105217
def test_partial_fit():
106218
est = PassiveAggressiveClassifier(random_state=0, shuffle=False)

0 commit comments

Comments
 (0)
0