Solved pickle problem and added test · scikit-learn/scikit-learn@62cf55b · GitHub
[go: up one dir, main page]

Skip to content

Commit 62cf55b

Browse files
committed
Solved pickle problem and added test
1 parent ec17096 commit 62cf55b

File tree

2 files changed

+26
-28
lines changed

2 files changed

+26
-28
lines changed

sklearn/feature_selection/multivariate_filtering.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def __init__(self, k=2, rule='diff'):
6363
"""
6464
self.k = k
6565
self.rule = rule
66-
self._rule_function = self._get_rule_function(rule)
6766

6867
def fit(self, X, y):
6968
"""
@@ -136,32 +135,26 @@ def _compute_mRMR(self, X, y):
136135
mask.append(ind)
137136
search_space.pop(ind)
138137

139-
fun = self._rule_function
140-
for m in range(0, self.k-1):
141-
tmp_score = fun(relevance[search_space],
142-
np.mean(redundancy[:, search_space].
143-
take(mask, axis=0), 0))
144-
score.append(max(tmp_score))
145-
ind = tmp_score.argmax(0)
146-
mask.append(search_space[ind])
147-
search_space.pop(ind)
148-
149-
return mask, score
150-
151-
def _get_rule_function(self, rule):
152-
"""
153-
Returns
154-
-------
155-
fun : function
156-
Function used to combine relevance (k) and redundancy (h) arrays
157-
"""
158-
if rule == 'diff':
159-
def fun(a, b):
160-
return a+b
161-
elif rule == 'prod':
162-
def fun(a, b):
163-
return a*b
138+
if self.rule == 'diff':
139+
for m in range(0, self.k-1):
140+
tmp_score = relevance[search_space] - \
141+
np.mean(redundancy[:, search_space]
142+
.take(mask, axis=0), 0)
143+
score.append(max(tmp_score))
144+
ind = tmp_score.argmax(0)
145+
mask.append(search_space[ind])
146+
search_space.pop(ind)
147+
148+
elif self.rule == 'prod':
149+
for m in range(0, self.k-1):
150+
tmp_score = relevance[search_space] * \
151+
np.mean(redundancy[:, search_space]
152+
.take(mask, axis=0), 0)
153+
score.append(max(tmp_score))
154+
ind = tmp_score.argmax(0)
155+
mask.append(search_space[ind])
156+
search_space.pop(ind)
164157
else:
165158
raise ValueError("rule should be either 'diff' or 'prod'")
166159

167-
return fun
160+
return mask, score

sklearn/feature_selection/tests/test_multivariate_filtering.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
y = np.array([3, 1, 3, 1, 3])
1414

15+
1516
def test_mMRM():
1617
"""
1718
Test MinRedundancyMaxRelevance with default setting.
@@ -23,4 +24,8 @@ def test_mMRM():
2324

2425
assert_array_equal(0.6730116670092563, m.score[0])
2526

26-
assert_raises(ValueError, MinRedundancyMaxRelevance, rule='none')
27+
m = MinRedundancyMaxRelevance(rule='prod').fit(X, y)
28+
29+
assert_array_equal(0.049793044493117354, m.score[1])
30+
31+
assert_raises(ValueError, MinRedundancyMaxRelevance(rule='none').fit, X, y)

0 commit comments

Comments
 (0)
0