|
16 | 16 | from sklearn.utils.testing import assert_less
|
17 | 17 | from sklearn.utils.testing import assert_warns
|
18 | 18 | from sklearn.utils.testing import ignore_warnings
|
| 19 | +from sklearn.utils.testing import assert_warns_message |
19 | 20 | from sklearn.utils import safe_mask
|
20 | 21 |
|
21 | 22 | from sklearn.datasets.samples_generator import (make_classification,
|
@@ -251,10 +252,13 @@ def test_select_kbest_zero():
|
251 | 252 | shuffle=False, random_state=0)
|
252 | 253 |
|
253 | 254 | univariate_filter = SelectKBest(f_classif, k=0)
|
254 |
| - univariate_filter.fit(X, y).transform(X) |
| 255 | + univariate_filter.fit(X, y) |
255 | 256 | support = univariate_filter.get_support()
|
256 | 257 | gtruth = np.zeros(10, dtype=bool)
|
257 | 258 | assert_array_equal(support, gtruth)
|
| 259 | + X_selected = assert_warns_message(UserWarning, 'No features were selected', |
| 260 | + univariate_filter.transform, X) |
| 261 | + assert_equal(X_selected.shape, (20, 0)) |
258 | 262 |
|
259 | 263 |
|
260 | 264 | def test_select_fpr_classif():
|
@@ -585,3 +589,24 @@ def test_f_classif_constant_feature():
|
585 | 589 | X, y = make_classification(n_samples=10, n_features=5)
|
586 | 590 | X[:, 0] = 2.0
|
587 | 591 | assert_warns(UserWarning, f_classif, X, y)
|
| 592 | + |
| 593 | + |
| 594 | +def test_no_feature_selected(): |
| 595 | + rng = np.random.RandomState(0) |
| 596 | + |
| 597 | + # Generate random uncorrelated data: a strict univariate test should |
| 598 | + # rejects all the features |
| 599 | + X = rng.rand(40, 10) |
| 600 | + y = rng.randint(0, 4, size=40) |
| 601 | + strict_selectors = [ |
| 602 | + SelectFwe(alpha=0.01).fit(X, y), |
| 603 | + SelectFdr(alpha=0.01).fit(X, y), |
| 604 | + SelectFpr(alpha=0.01).fit(X, y), |
| 605 | + SelectPercentile(percentile=0).fit(X, y), |
| 606 | + SelectKBest(k=0).fit(X, y), |
| 607 | + ] |
| 608 | + for selector in strict_selectors: |
| 609 | + assert_array_equal(selector.get_support(), np.zeros(10)) |
| 610 | + X_selected = assert_warns_message( |
| 611 | + UserWarning, 'No features were selected', selector.transform, X) |
| 612 | + assert_equal(X_selected.shape, (40, 0)) |
0 commit comments