Merge pull request #4206 from ogrisel/fix-strict-select-fdr · scikit-learn/scikit-learn@95681ee · GitHub
[go: up one dir, main page]

Skip to content

Commit 95681ee

Browse files
committed
Merge pull request #4206 from ogrisel/fix-strict-select-fdr
[MRG] explicit warning message for strict selectors Also fixes #4059
2 parents eaf1e8c + 2ce9eac commit 95681ee

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

sklearn/feature_selection/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# License: BSD 3 clause
66

77
from abc import ABCMeta, abstractmethod
8+
from warnings import warn
89

910
import numpy as np
1011
from scipy.sparse import issparse, csc_matrix
@@ -73,6 +74,11 @@ def transform(self, X):
7374
"""
7475
X = check_array(X, accept_sparse='csr')
7576
mask = self.get_support()
77+
if not mask.any():
78+
warn("No features were selected: either the data is"
79+
" too noisy or the selection test too strict.",
80+
UserWarning)
81+
return np.empty(0).reshape((X.shape[0], 0))
7682
if len(mask) != X.shape[1]:
7783
raise ValueError("X has a different shape than during fitting.")
7884
return check_array(X, accept_sparse='csr')[:, safe_mask(X, mask)]

sklearn/feature_selection/tests/test_feature_select.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sklearn.utils.testing import assert_less
1717
from sklearn.utils.testing import assert_warns
1818
from sklearn.utils.testing import ignore_warnings
19+
from sklearn.utils.testing import assert_warns_message
1920
from sklearn.utils import safe_mask
2021

2122
from sklearn.datasets.samples_generator import (make_classification,
@@ -251,10 +252,13 @@ def test_select_kbest_zero():
251252
shuffle=False, random_state=0)
252253

253254
univariate_filter = SelectKBest(f_classif, k=0)
254-
univariate_filter.fit(X, y).transform(X)
255+
univariate_filter.fit(X, y)
255256
support = univariate_filter.get_support()
256257
gtruth = np.zeros(10, dtype=bool)
257258
assert_array_equal(support, gtruth)
259+
X_selected = assert_warns_message(UserWarning, 'No features were selected',
260+
univariate_filter.transform, X)
261+
assert_equal(X_selected.shape, (20, 0))
258262

259263

260264
def test_select_fpr_classif():
@@ -585,3 +589,24 @@ def test_f_classif_constant_feature():
585589
X, y = make_classification(n_samples=10, n_features=5)
586590
X[:, 0] = 2.0
587591
assert_warns(UserWarning, f_classif, X, y)
592+
593+
594+
def test_no_feature_selected():
595+
rng = np.random.RandomState(0)
596+
597+
# Generate random uncorrelated data: a strict univariate test should
598+
# rejects all the features
599+
X = rng.rand(40, 10)
600+
y = rng.randint(0, 4, size=40)
601+
strict_selectors = [
602+
SelectFwe(alpha=0.01).fit(X, y),
603+
SelectFdr(alpha=0.01).fit(X, y),
604+
SelectFpr(alpha=0.01).fit(X, y),
605+
SelectPercentile(percentile=0).fit(X, y),
606+
SelectKBest(k=0).fit(X, y),
607+
]
608+
for selector in strict_selectors:
609+
assert_array_equal(selector.get_support(), np.zeros(10))
610+
X_selected = assert_warns_message(
611+
UserWarning, 'No features were selected', selector.transform, X)
612+
assert_equal(X_selected.shape, (40, 0))

sklearn/feature_selection/univariate_selection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,10 @@ def _get_support_mask(self):
496496

497497
alpha = self.alpha
498498
sv = np.sort(self.pvalues_)
499-
threshold = sv[sv < alpha * np.arange(len(self.pvalues_))].max()
500-
return self.pvalues_ <= threshold
499+
selected = sv[sv < alpha * np.arange(len(self.pvalues_))]
500+
if selected.size == 0:
501+
return np.zeros_like(self.pvalues_, dtype=bool)
502+
return self.pvalues_ <= selected.max()
501503

502504

503505
class SelectFwe(_BaseFilter):

0 commit comments

Comments
 (0)
0