@@ -63,7 +63,6 @@ def __init__(self, k=2, rule='diff'):
63
63
"""
64
64
self .k = k
65
65
self .rule = rule
66
- self ._rule_function = self ._get_rule_function (rule )
67
66
68
67
def fit (self , X , y ):
69
68
"""
@@ -136,32 +135,26 @@ def _compute_mRMR(self, X, y):
136
135
mask .append (ind )
137
136
search_space .pop (ind )
138
137
139
- fun = self ._rule_function
140
- for m in range (0 , self .k - 1 ):
141
- tmp_score = fun (relevance [search_space ],
142
- np .mean (redundancy [:, search_space ].
143
- take (mask , axis = 0 ), 0 ))
144
- score .append (max (tmp_score ))
145
- ind = tmp_score .argmax (0 )
146
- mask .append (search_space [ind ])
147
- search_space .pop (ind )
148
-
149
- return mask , score
150
-
151
- def _get_rule_function (self , rule ):
152
- """
153
- Returns
154
- -------
155
- fun : function
156
- Function used to combine relevance (k) and redundancy (h) arrays
157
- """
158
- if rule == 'diff' :
159
- def fun (a , b ):
160
- return a + b
161
- elif rule == 'prod' :
162
- def fun (a , b ):
163
- return a * b
138
+ if self .rule == 'diff' :
139
+ for m in range (0 , self .k - 1 ):
140
+ tmp_score = relevance [search_space ] - \
141
+ np .mean (redundancy [:, search_space ]
142
+ .take (mask , axis = 0 ), 0 )
143
+ score .append (max (tmp_score ))
144
+ ind = tmp_score .argmax (0 )
145
+ mask .append (search_space [ind ])
146
+ search_space .pop (ind )
147
+
148
+ elif self .rule == 'prod' :
149
+ for m in range (0 , self .k - 1 ):
150
+ tmp_score = relevance [search_space ] * \
151
+ np .mean (redundancy [:, search_space ]
152
+ .take (mask , axis = 0 ), 0 )
153
+ score .append (max (tmp_score ))
154
+ ind = tmp_score .argmax (0 )
155
+ mask .append (search_space [ind ])
156
+ search_space .pop (ind )
164
157
else :
165
158
raise ValueError ("rule should be either 'diff' or 'prod'" )
166
159
167
- return fun
160
+ return mask , score
0 commit comments