Merge pull request #4146 from amueller/fdr_treshold_bug_2 · ogrisel/scikit-learn@92da7c0

Commit 92da7c0

Merge pull request scikit-learn#4146 from amueller/fdr_treshold_bug_2

[MRG + 1] Fdr treshold bug

2 parents 3468e00 + f12a441

3 files changed (+81, -87 lines)

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions

@@ -306,6 +306,9 @@ Bug fixes
     - Fix log-density calculation in the :class:`mixture.GMM` with
       tied covariance. By `Will Dawson`_
 
+    - Fixed a scaling error in :class:`feature_selection.SelectFdr`
+      where a factor ``n_features`` was missing. By `Andrew Tulloch`_
+
 API changes summary
 -------------------
 

@@ -3286,3 +3289,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Will Dawson: http://dawsonresearch.com
 
 .. _Balazs Kegl: https://github.com/kegl
+
+.. _Andrew Tulloch: http://tullo.ch/
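
The entry above refers to the Benjamini-Hochberg cut-off: the i-th smallest p-value should be compared against alpha * i / n_features, but the old code compared it against alpha * i. A minimal sketch contrasting the two on toy p-values (illustration only, not library code; it uses textbook 1-based ranks, whereas the fixed library code below uses a 0-based np.arange):

import numpy as np

pvalues = np.array([0.001, 0.002, 0.04, 0.2, 0.5])
alpha = 0.05
n_features = len(pvalues)
sv = np.sort(pvalues)
ranks = np.arange(1, n_features + 1)  # 1-based rank of each sorted p-value

# Old behaviour: the 1 / n_features factor was missing, so the threshold
# grew far too quickly and admitted too many features.
buggy = sv < alpha * ranks
# Fixed behaviour: the standard Benjamini-Hochberg comparison.
fixed = sv <= alpha * ranks / n_features

print(buggy)  # [ True  True  True False False]
print(fixed)  # [ True  True False False False]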

sklearn/feature_selection/tests/test_feature_select.py

Lines changed: 66 additions & 83 deletions
@@ -1,8 +1,9 @@
 """
 Todo: cross-check the F-value with stats model
 """
-
+from __future__ import division
 import itertools
+import warnings
 import numpy as np
 from scipy import stats, sparse
 

@@ -17,6 +18,8 @@
 from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_warns_message
+from sklearn.utils.testing import assert_greater
+from sklearn.utils.testing import assert_greater_equal
 from sklearn.utils import safe_mask
 
 from sklearn.datasets.samples_generator import (make_classification,
@@ -67,7 +70,7 @@ def test_f_classif():
                                class_sep=10, shuffle=False, random_state=0)
 
     F, pv = f_classif(X, y)
-    F_sparse, pv_sparse = f_classif(sparse.csr_matrix(X), y)
+    F_sparse, pv_sparse = f_classif(sparse.csr_matrix(X), y)
     assert_true((F > 0).all())
     assert_true((pv > 0).all())
     assert_true((pv < 1).all())
@@ -261,57 +264,11 @@ def test_select_kbest_zero():
     assert_equal(X_selected.shape, (20, 0))
 
 
-def test_select_fpr_classif():
+def test_select_heuristics_classif():
     """
     Test whether the relative univariate feature selection
     gets the correct items in a simple classification problem
-    with the fpr heuristic
-    """
-    X, y = make_classification(n_samples=200, n_features=20,
-                               n_informative=3, n_redundant=2,
-                               n_repeated=0, n_classes=8,
-                               n_clusters_per_class=1, flip_y=0.0,
-                               class_sep=10, shuffle=False, random_state=0)
-
-    univariate_filter = SelectFpr(f_classif, alpha=0.0001)
-    X_r = univariate_filter.fit(X, y).transform(X)
-    X_r2 = GenericUnivariateSelect(
-        f_classif, mode='fpr', param=0.0001).fit(X, y).transform(X)
-    assert_array_equal(X_r, X_r2)
-    support = univariate_filter.get_support()
-    gtruth = np.zeros(20)
-    gtruth[:5] = 1
-    assert_array_equal(support, gtruth)
-
-
-def test_select_fdr_classif():
-    """
-    Test whether the relative univariate feature selection
-    gets the correct items in a simple classification problem
-    with the fpr heuristic
-    """
-    X, y = make_classification(n_samples=200, n_features=20,
-                               n_informative=3, n_redundant=2,
-                               n_repeated=0, n_classes=8,
-                               n_clusters_per_class=1, flip_y=0.0,
-                               class_sep=10, shuffle=False, random_state=0)
-
-    univariate_filter = SelectFdr(f_classif, alpha=0.0001)
-    X_r = univariate_filter.fit(X, y).transform(X)
-    X_r2 = GenericUnivariateSelect(
-        f_classif, mode='fdr', param=0.0001).fit(X, y).transform(X)
-    assert_array_equal(X_r, X_r2)
-    support = univariate_filter.get_support()
-    gtruth = np.zeros(20)
-    gtruth[:5] = 1
-    assert_array_equal(support, gtruth)
-
-
-def test_select_fwe_classif():
-    """
-    Test whether the relative univariate feature selection
-    gets the correct items in a simple classification problem
-    with the fpr heuristic
+    with the fdr, fwe and fpr heuristics
     """
     X, y = make_classification(n_samples=200, n_features=20,
                                n_informative=3, n_redundant=2,
@@ -321,13 +278,14 @@ def test_select_fwe_classif():
 
     univariate_filter = SelectFwe(f_classif, alpha=0.01)
     X_r = univariate_filter.fit(X, y).transform(X)
-    X_r2 = GenericUnivariateSelect(
-        f_classif, mode='fwe', param=0.01).fit(X, y).transform(X)
-    assert_array_equal(X_r, X_r2)
-    support = univariate_filter.get_support()
     gtruth = np.zeros(20)
     gtruth[:5] = 1
-    assert_array_almost_equal(support, gtruth)
+    for mode in ['fdr', 'fpr', 'fwe']:
+        X_r2 = GenericUnivariateSelect(
+            f_classif, mode=mode, param=0.01).fit(X, y).transform(X)
+        assert_array_equal(X_r, X_r2)
+        support = univariate_filter.get_support()
+        assert_array_almost_equal(support, gtruth)
 
 
 ##############################################################################
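
As the rewritten test asserts, GenericUnivariateSelect with mode='fdr' is the functional twin of SelectFdr. A minimal usage sketch under that assumption (illustrative only; it reuses the dataset parameters from the test above):

from sklearn.datasets import make_classification
from sklearn.feature_selection import (GenericUnivariateSelect, SelectFdr,
                                       f_classif)

X, y = make_classification(n_samples=200, n_features=20, n_informative=3,
                           n_redundant=2, n_repeated=0, n_classes=8,
                           n_clusters_per_class=1, flip_y=0.0, class_sep=10,
                           shuffle=False, random_state=0)

# Both selectors should keep exactly the same columns.
X_a = SelectFdr(f_classif, alpha=0.01).fit_transform(X, y)
X_b = GenericUnivariateSelect(f_classif, mode='fdr',
                              param=0.01).fit_transform(X, y)
print((X_a == X_b).all())  # True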
@@ -405,8 +363,8 @@ def test_select_kbest_regression():
     gets the correct items in a simple regression problem
     with the k best heuristic
     """
-    X, y = make_regression(n_samples=200, n_features=20,
-                           n_informative=5, shuffle=False, random_state=0)
+    X, y = make_regression(n_samples=200, n_features=20, n_informative=5,
+                           shuffle=False, random_state=0, noise=10)
 
     univariate_filter = SelectKBest(f_regression, k=5)
     X_r = univariate_filter.fit(X, y).transform(X)
@@ -420,45 +378,70 @@ def test_select_kbest_regression():
     assert_array_equal(support, gtruth)
 
 
-def test_select_fpr_regression():
+def test_select_heuristics_regression():
     """
     Test whether the relative univariate feature selection
     gets the correct items in a simple regression problem
-    with the fpr heuristic
+    with the fpr, fdr or fwe heuristics
     """
-    X, y = make_regression(n_samples=200, n_features=20,
-                           n_informative=5, shuffle=False, random_state=0)
+    X, y = make_regression(n_samples=200, n_features=20, n_informative=5,
+                           shuffle=False, random_state=0, noise=10)
 
     univariate_filter = SelectFpr(f_regression, alpha=0.01)
     X_r = univariate_filter.fit(X, y).transform(X)
-    X_r2 = GenericUnivariateSelect(
-        f_regression, mode='fpr', param=0.01).fit(X, y).transform(X)
-    assert_array_equal(X_r, X_r2)
-    support = univariate_filter.get_support()
     gtruth = np.zeros(20)
     gtruth[:5] = 1
-    assert_array_equal(support[:5], np.ones((5, ), dtype=np.bool))
-    assert_less(np.sum(support[5:] == 1), 3)
+    for mode in ['fdr', 'fpr', 'fwe']:
+        X_r2 = GenericUnivariateSelect(
+            f_regression, mode=mode, param=0.01).fit(X, y).transform(X)
+        assert_array_equal(X_r, X_r2)
+        support = univariate_filter.get_support()
+        assert_array_equal(support[:5], np.ones((5, ), dtype=np.bool))
+        assert_less(np.sum(support[5:] == 1), 3)
 
 
 def test_select_fdr_regression():
     """
-    Test whether the relative univariate feature selection
-    gets the correct items in a simple regression problem
-    with the fdr heuristic
-    """
-    X, y = make_regression(n_samples=200, n_features=20,
-                           n_informative=5, shuffle=False, random_state=0)
-
-    univariate_filter = SelectFdr(f_regression, alpha=0.01)
-    X_r = univariate_filter.fit(X, y).transform(X)
-    X_r2 = GenericUnivariateSelect(
-        f_regression, mode='fdr', param=0.01).fit(X, y).transform(X)
-    assert_array_equal(X_r, X_r2)
-    support = univariate_filter.get_support()
-    gtruth = np.zeros(20)
-    gtruth[:5] = 1
-    assert_array_equal(support, gtruth)
+    Test that fdr heuristic actually has low FDR.
+    """
+    def single_fdr(alpha, n_informative, random_state):
+        X, y = make_regression(n_samples=150, n_features=20,
+                               n_informative=n_informative, shuffle=False,
+                               random_state=random_state, noise=10)
+
+        with warnings.catch_warnings(record=True):
+            # Warnings can be raised when no features are selected
+            # (low alpha or very noisy data)
+            univariate_filter = SelectFdr(f_regression, alpha=alpha)
+            X_r = univariate_filter.fit(X, y).transform(X)
+            X_r2 = GenericUnivariateSelect(
+                f_regression, mode='fdr', param=alpha).fit(X, y).transform(X)
+
+        assert_array_equal(X_r, X_r2)
+        support = univariate_filter.get_support()
+        num_false_positives = np.sum(support[n_informative:] == 1)
+        num_true_positives = np.sum(support[:n_informative] == 1)
+
+        if num_false_positives == 0:
+            return 0.
+        false_discovery_rate = (num_false_positives /
+                                (num_true_positives + num_false_positives))
+        return false_discovery_rate
+
+    for alpha in [0.001, 0.01, 0.1]:
+        for n_informative in [1, 5, 10]:
+            # As per Benjamini-Hochberg, the expected false discovery rate
+            # should be lower than alpha:
+            # FDR = E(FP / (TP + FP)) <= alpha
+            false_discovery_rate = np.mean([single_fdr(alpha, n_informative,
+                                                       random_state) for
+                                            random_state in range(30)])
+            assert_greater_equal(alpha, false_discovery_rate)
+
+            # Make sure that the empirical false discovery rate increases
+            # with alpha:
+            if false_discovery_rate != 0:
+                assert_greater(false_discovery_rate, alpha / 10)
 
 
 def test_select_fwe_regression():
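
The new test_select_fdr_regression averages the per-seed false discovery proportion over 30 draws and checks that the average stays below alpha. A standalone sketch of the same measurement, assuming a scikit-learn build that includes this fix (variable names are local to the sketch, not part of the library):

from __future__ import division
import numpy as np
from sklearn.datasets import make_regression
from sklearn.feature_selection import SelectFdr, f_regression

alpha, n_informative = 0.1, 5
rates = []
for seed in range(30):
    X, y = make_regression(n_samples=150, n_features=20,
                           n_informative=n_informative, shuffle=False,
                           random_state=seed, noise=10)
    support = SelectFdr(f_regression, alpha=alpha).fit(X, y).get_support()
    fp = support[n_informative:].sum()  # noise features that were selected
    tp = support[:n_informative].sum()  # informative features recovered
    rates.append(fp / (tp + fp) if fp else 0.0)

# Benjamini-Hochberg controls E[FP / (TP + FP)] <= alpha.
print(np.mean(rates) <= alpha)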

sklearn/feature_selection/univariate_selection.py

Lines changed: 10 additions & 4 deletions
@@ -465,8 +465,8 @@ def _get_support_mask(self):
 class SelectFdr(_BaseFilter):
     """Filter: Select the p-values for an estimated false discovery rate
 
-    This uses the Benjamini-Hochberg procedure. ``alpha`` is the target false
-    discovery rate.
+    This uses the Benjamini-Hochberg procedure. ``alpha`` is an upper bound
+    on the expected false discovery rate.
 
     Parameters
     ----------

@@ -485,6 +485,11 @@ class SelectFdr(_BaseFilter):
 
     pvalues_ : array-like, shape=(n_features,)
         p-values of feature scores.
+
+    References
+    ----------
+    http://en.wikipedia.org/wiki/False_discovery_rate
+
     """
 
     def __init__(self, score_func=f_classif, alpha=5e-2):

@@ -494,9 +499,10 @@ def __init__(self, score_func=f_classif, alpha=5e-2):
     def _get_support_mask(self):
         check_is_fitted(self, 'scores_')
 
-        alpha = self.alpha
+        n_features = len(self.pvalues_)
         sv = np.sort(self.pvalues_)
-        selected = sv[sv < alpha * np.arange(len(self.pvalues_))]
+        selected = sv[sv <= float(self.alpha) / n_features
+                      * np.arange(n_features)]
         if selected.size == 0:
             return np.zeros_like(self.pvalues_, dtype=bool)
         return self.pvalues_ <= selected.max()
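
Read on its own, the fixed _get_support_mask implements the standard Benjamini-Hochberg step-up rule. A self-contained sketch of the same logic, for illustration outside the class (not scikit-learn's actual code path):

import numpy as np

def bh_select(pvalues, alpha=0.05):
    # Keep every feature whose p-value is at most the largest sorted
    # p-value that clears its rank-scaled threshold; mirrors the fixed
    # comparison sv <= alpha / n_features * arange(n_features).
    pvalues = np.asarray(pvalues, dtype=float)
    n_features = len(pvalues)
    sv = np.sort(pvalues)
    selected = sv[sv <= float(alpha) / n_features * np.arange(n_features)]
    if selected.size == 0:
        return np.zeros_like(pvalues, dtype=bool)
    return pvalues <= selected.max()

print(bh_select([0.001, 0.002, 0.04, 0.2, 0.5]))
# -> [ True  True False False False]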
