3
3
"""
4
4
import numpy as np
5
5
from numpy .testing import assert_almost_equal
6
+ from numpy .testing import assert_array_equal
6
7
from sklearn .utils .testing import assert_raise_message
7
8
from sklearn .neighbors import KNeighborsClassifier
8
9
from sklearn .linear_model import LinearRegression
@@ -37,7 +38,8 @@ def test_run_default():
37
38
knn = KNeighborsClassifier ()
38
39
sfs = SFS (estimator = knn )
39
40
sfs .fit (X , y )
40
- assert sfs .feature_subset_idx_ == (3 , )
41
+ assert_array_equal (sfs .support_ ,
42
+ np .array ([False , False , False , True ]))
41
43
42
44
43
45
def test_kfeatures_type_1 ():
@@ -163,7 +165,8 @@ def test_knn_option_sbs():
163
165
forward = False ,
164
166
cv = 4 )
165
167
sfs3 = sfs3 .fit (X , y )
166
- assert sfs3 .feature_subset_idx_ == (1 , 2 , 3 )
168
+ assert_array_equal (sfs3 .support_ ,
169
+ np .array ([False , True , True , True ]))
167
170
168
171
169
172
def test_knn_option_sfs_tuplerange ():
@@ -177,7 +180,8 @@ def test_knn_option_sfs_tuplerange():
177
180
cv = 4 )
178
181
sfs4 = sfs4 .fit (X , y )
179
182
assert round (sfs4 .score_ , 3 ) == 0.967
180
- assert sfs4 .feature_subset_idx_ == (0 , 2 , 3 )
183
+ assert_array_equal (sfs4 .support_ ,
184
+ np .array ([True , False , True , True ]))
181
185
182
186
183
187
def test_knn_scoring_metric ():
@@ -217,7 +221,7 @@ def test_regression():
217
221
forward = True ,
218
222
cv = 10 )
219
223
sfs_r = sfs_r .fit (X , y )
220
- assert len (sfs_r .feature_subset_idx_ ) == 13
224
+ assert sum (sfs_r .support_ ) == 13
221
225
assert round (sfs_r .score_ , 4 ) == 0.2001
222
226
223
227
@@ -233,7 +237,7 @@ def test_regression_in_tuplerange_forward():
233
237
forward = True ,
234
238
cv = 10 )
235
239
sfs_r = sfs_r .fit (X , y )
236
- assert len (sfs_r .feature_subset_idx_ ) == 9
240
+ assert sum (sfs_r .support_ ) == 9
237
241
assert round (sfs_r .score_ , 4 ) == 0.2991 , sfs_r .score_
238
242
239
243
@@ -252,7 +256,12 @@ def test_regression_in_tuplerange_backward():
252
256
cv = 10 )
253
257
254
258
sfs_r = sfs_r .fit (X , y )
255
- assert len (sfs_r .feature_subset_idx_ ) == 5
259
+ print (sfs_r .support_ )
260
+ print (type (sfs_r .support_ ))
261
+ assert_array_equal (sfs_r .support_ ,
262
+ np .array ([False , False , False , False ,
263
+ True , False , False , True ,
264
+ True , False , True , False , True ]))
256
265
257
266
258
267
def test_transform_not_fitted ():
0 commit comments