ENH: Small improvements/fixes to mutual info feature selection · scikit-learn/scikit-learn@562a8e8 · GitHub

Commit 562a8e8

Author: Nikolay Mayorov (committed)

ENH: Small improvements/fixes to mutual info feature selection

1 parent 96a9a81 · commit 562a8e8

File tree: 3 files changed, +27 −20 lines

sklearn/feature_selection/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -17,11 +17,11 @@
 
 from .variance_threshold import VarianceThreshold
 
-from .multivariate_filtering import MinRedundancyMaxRelevance
-
 from .rfe import RFE
 from .rfe import RFECV
 
+from .mutual_info import MinRedundancyMaxRelevance
+
 __all__ = ['GenericUnivariateSelect',
            'MinRedundancyMaxRelevance',
            'RFE',
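
Note that the public import path is unaffected by the module rename; only the private module moves. A minimal sketch, assuming this development branch is installed (MinRedundancyMaxRelevance is not part of any scikit-learn release):

    # Public path, re-exported via __all__ -- unchanged by this commit.
    from sklearn.feature_selection import MinRedundancyMaxRelevance

    # Private module path -- moved by this commit from multivariate_filtering
    # to mutual_info (direct imports of private modules are discouraged and
    # shown only for illustration).
    from sklearn.feature_selection.mutual_info import MinRedundancyMaxRelevance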

sklearn/feature_selection/multivariate_filtering.py renamed to sklearn/feature_selection/mutual_info.py

Lines changed: 19 additions & 13 deletions
@@ -108,9 +108,12 @@ class MinRedundancyMaxRelevance(BaseEstimator, SelectorMixin):
 
     Parameters
     ----------
-    n_features_to_select : None or int, optional (default=None)
-        Number of features to select. If None, half of the features
-        will be selected.
+    n_features_to_select : float or int, optional (default=0.5)
+        Number of features to select. A value greater than or equal to 1
+        is interpreted as the absolute number of features to select. A
+        value within (0.0, 1.0) is interpreted as the fraction of the
+        initial number of features to keep (rounded down, but at least
+        one feature). By default half of the features are selected.
     categorical_features : bool or array_like with shape (n_features),
                            optional (default=False)
         If bool, then determines whether to consider all features categorical
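
A worked example of the new semantics, mirroring the validation logic added to fit() further down. Note the boundary case: 1.0 falls into the >= 1 branch, so it selects exactly one feature, not all of them.

    # Plain-Python sketch of how n_features_to_select resolves to a count
    # for a hypothetical dataset with 10 features.
    n_features = 10

    for value in (3, 0.25, 0.5, 1.0):
        if value >= 1:
            k = int(value)                       # absolute count: 3 -> 3, 1.0 -> 1
        elif 0 < value < 1:
            k = max(1, int(value * n_features))  # fraction: 0.25 -> 2, 0.5 -> 5
        else:
            raise ValueError("`n_features_to_select` must be positive.")
        print(value, "->", k, "features")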
@@ -124,8 +127,6 @@ class MinRedundancyMaxRelevance(BaseEstimator, SelectorMixin):
 
     Attributes
     ----------
-    n_features_ : int
-        Number of selected features.
     support_ : ndarray, shape (n_features,)
         Mask of selected features.
     relevance_ : ndarray, shape (n_features,)
@@ -147,7 +148,7 @@ class MinRedundancyMaxRelevance(BaseEstimator, SelectorMixin):
     .. [3] B. C. Ross "Mutual Information between Discrete and Continuous
        Data Sets". PLoS ONE 9(2), 2014.
     """
-    def __init__(self, n_features_to_select=None, categorical_features=False,
+    def __init__(self, n_features_to_select=0.5, categorical_features=False,
                  categorical_target=False, n_neighbors=3):
         self.n_features_to_select = n_features_to_select
         self.categorical_features = categorical_features
@@ -171,9 +172,19 @@ def fit(self, X, y):
         -------
         self
         """
-        X, y = check_X_y(X, y, accept_sparse='csc')
+        X, y = check_X_y(X, y, accept_sparse='csc',
+                         y_numeric=not self.categorical_target)
 
         n_features = X.shape[1]
+
+        if self.n_features_to_select >= 1:
+            n_features_to_select = int(self.n_features_to_select)
+        elif 0 < self.n_features_to_select < 1:
+            n_features_to_select = max(
+                1, int(self.n_features_to_select * n_features))
+        else:
+            raise ValueError("`n_features_to_select` must be positive.")
+
         if isinstance(self.categorical_features, bool):
             categorical_features = np.empty(n_features, dtype=bool)
             categorical_features.fill(self.categorical_features)
@@ -203,14 +214,9 @@ def fit(self, X, y):
                     xi, xj, categorical_features[i], categorical_features[j],
                     self.n_neighbors)
 
-        if self.n_features_to_select is None:
-            self.n_features_ = (n_features + 1) // 2
-        else:
-            self.n_features_ = self.n_features_to_select
-
         support = np.zeros(n_features, dtype=bool)
         support[np.argmax(relevance)] = True
-        for i in range(self.n_features_ - 1):
+        for i in range(n_features_to_select - 1):
             selected = np.nonzero(support)[0]
             candidates = np.nonzero(~support)[0]
             D = relevance[candidates]
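
For context, the loop above is a greedy forward selection: seed with the single most relevant feature, then repeatedly add the candidate that best trades relevance against redundancy with the already-selected set. The hunk cuts off right after the relevance term, so the scoring that follows is an assumption here; this sketch uses the classic mRMR difference criterion (relevance minus mean redundancy) on precomputed mutual-information scores:

    import numpy as np

    def mrmr_select(relevance, redundancy, n_features_to_select):
        """Greedy mRMR on precomputed MI scores (illustrative sketch).

        relevance  : (n_features,) MI between each feature and the target.
        redundancy : (n_features, n_features) pairwise MI between features.
        """
        n_features = relevance.shape[0]
        support = np.zeros(n_features, dtype=bool)
        support[np.argmax(relevance)] = True       # seed: most relevant feature

        for _ in range(n_features_to_select - 1):
            selected = np.nonzero(support)[0]
            candidates = np.nonzero(~support)[0]
            D = relevance[candidates]              # relevance of each candidate
            # Mean redundancy of each candidate w.r.t. the selected set; the
            # difference criterion D - R is an assumption, not taken verbatim
            # from this commit.
            R = redundancy[np.ix_(candidates, selected)].mean(axis=1)
            support[candidates[np.argmax(D - R)]] = True

        return support

    # Tiny demo: feature 2 is highly relevant but redundant with feature 0,
    # so the less redundant feature 1 is chosen second.
    rel = np.array([0.9, 0.3, 0.8, 0.1])
    red = np.array([[0.0, 0.1, 0.7, 0.0],
                    [0.1, 0.0, 0.1, 0.0],
                    [0.7, 0.1, 0.0, 0.1],
                    [0.0, 0.0, 0.1, 0.0]])
    print(mrmr_select(rel, red, 2))   # -> [ True  True False False]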

sklearn/feature_selection/tests/test_multivariate_filtering.py renamed to sklearn/feature_selection/tests/test_mutual_info.py

Lines changed: 6 additions & 5 deletions
@@ -7,7 +7,7 @@
 from sklearn.utils.testing import (
     assert_array_equal, assert_almost_equal, assert_true)
 from sklearn.feature_selection import MinRedundancyMaxRelevance
-from sklearn.feature_selection.multivariate_filtering import _compute_mi
+from sklearn.feature_selection.mutual_info import _compute_mi
 
 
 class TestMIComputation(object):
@@ -108,8 +108,9 @@ def test_categorical(self):
         # (thus redundant) and x[3] is weakly informative. So the algorithm
         # should pick features 0 and 2.
 
-        m = MinRedundancyMaxRelevance(categorical_features=True,
-                                      categorical_target=True)
+        m = MinRedundancyMaxRelevance(
+            categorical_features=True, categorical_target=True,
+            n_features_to_select=2)
         m.fit(X, y)
         assert_array_equal(m.support_, np.array([True, False, True, False]))
 
@@ -136,7 +137,7 @@ def test_continuous(self):
 
         y = Z[:, 0]
         X = Z[:, 1:]
-        m = MinRedundancyMaxRelevance()
+        m = MinRedundancyMaxRelevance(n_features_to_select=2)
         m.fit(X, y)
         assert_array_equal(m.support_, np.array([True, False, True]))
 
@@ -148,7 +149,7 @@ def test_mixed(self):
         X[:, 2] = X[:, 2] > 0.5
 
         m = MinRedundancyMaxRelevance(
-            categorical_features=[False, False, True],
+            categorical_features=[False, False, True], n_features_to_select=2,
             categorical_target=True)
         m.fit(X, y)
         assert_array_equal(m.support_, np.array([True, False, True]))
