FIX Ignore zero sample weights in precision recall curve (#18328) · thomasjpfan/scikit-learn@d4696c0 · GitHub

Commit d4696c0

albertvillanova and Alonso Silva Allende authored and committed
FIX Ignore zero sample weights in precision recall curve (scikit-learn#18328)
Co-authored-by: Alonso Silva Allende <alonsosilva@gmail.com>
1 parent 2f5fdea commit d4696c0

4 files changed: +89 -66 lines changed

doc/whats_new/v1.0.rst

Lines changed: 9 additions & 3 deletions
@@ -241,6 +241,12 @@ Changelog
   are integral.
   :pr:`9843` by :user:`Jon Crall <Erotemic>`.
 
+- |Fix| Samples with zero `sample_weight` values do not affect the results
+  from :func:`metrics.det_curve`, :func:`metrics.precision_recall_curve`
+  and :func:`metrics.roc_curve`.
+  :pr:`18328` by :user:`Albert Villanova del Moral <albertvillanova>` and
+  :user:`Alonso Silva Allende <alonsosilvaallende>`.
+
 :mod:`sklearn.model_selection`
 ..............................

@@ -325,9 +331,9 @@ Changelog
   :pr:`19459` by :user:`Cindy Bezuidenhout <cinbez>` and
   :user:`Clifford Akai-Nettey<cliffordEmmanuel>`.
 
-- |Fix| Fixed a bug in :func:`utils.sparsefuncs.mean_variance_axis` where the
-  precision of the computed variance was very poor when the real variance is
-  exactly zero. :pr:`19766` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
+- |Fix| Fixed a bug in :func:`utils.sparsefuncs.mean_variance_axis` where the
+  precision of the computed variance was very poor when the real variance is
+  exactly zero. :pr:`19766` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
 Code and Documentation Contributors
 -----------------------------------

(The removed and re-added lines above render identically; the original hunk changed only whitespace.)
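
As a quick illustration of the behavior the new changelog entry promises, a minimal sketch using only the public API (the toy arrays are illustrative, not taken from the test suite):

    from numpy.testing import assert_allclose
    from sklearn.metrics import precision_recall_curve

    y_true = [0, 0, 1, 1, 1]
    y_score = [0.1, 0.2, 0.3, 0.4, 0.5]

    # A sample carrying zero weight ...
    with_zero = precision_recall_curve(
        y_true, y_score, sample_weight=[1, 1, 1, 1, 0])
    # ... must yield the same curve as dropping that sample outright.
    dropped = precision_recall_curve(
        y_true[:-1], y_score[:-1], sample_weight=[1, 1, 1, 1])

    for arr_a, arr_b in zip(with_zero, dropped):
        assert_allclose(arr_a, arr_b)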

sklearn/metrics/_ranking.py

Lines changed: 14 additions & 5 deletions
@@ -27,6 +27,7 @@
 
 from ..utils import assert_all_finite
 from ..utils import check_consistent_length
+from ..utils.validation import _check_sample_weight
 from ..utils import column_or_1d, check_array
 from ..utils.multiclass import type_of_target
 from ..utils.extmath import stable_cumsum
@@ -291,14 +292,14 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
     >>> thresholds
     array([0.35, 0.4 , 0.8 ])
     """
-    if len(np.unique(y_true)) != 2:
-        raise ValueError("Only one class present in y_true. Detection error "
-                         "tradeoff curve is not defined in that case.")
-
     fps, tps, thresholds = _binary_clf_curve(
         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
     )
 
+    if len(np.unique(y_true)) != 2:
+        raise ValueError("Only one class present in y_true. Detection error "
+                         "tradeoff curve is not defined in that case.")
+
     fns = tps[-1] - tps
     p_count = tps[-1]
     n_count = fps[-1]
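
With this reordering, `_binary_clf_curve` validates the inputs first, so format errors (e.g. multiclass input) surface before the single-class check; the error itself is unchanged. A quick sketch of that error:

    from sklearn.metrics import det_curve

    try:
        det_curve([1, 1, 1], [0.1, 0.5, 0.9])  # only one class in y_true
    except ValueError as exc:
        print(exc)
    # Only one class present in y_true. Detection error tradeoff curve is
    # not defined in that case.
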
@@ -696,8 +697,14 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
     assert_all_finite(y_true)
     assert_all_finite(y_score)
 
+    # Filter out zero-weighted samples, as they should not impact the result
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
+        sample_weight = _check_sample_weight(sample_weight, y_true)
+        nonzero_weight_mask = sample_weight != 0
+        y_true = y_true[nonzero_weight_mask]
+        y_score = y_score[nonzero_weight_mask]
+        sample_weight = sample_weight[nonzero_weight_mask]
 
     pos_label = _check_pos_label_consistency(pos_label, y_true)
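
The heart of the fix is a boolean mask over the weights. A standalone sketch of the same idiom in plain NumPy (names mirror the diff; `_check_sample_weight` is approximated by `np.asarray` because this sketch runs outside scikit-learn):

    import numpy as np

    y_true = np.array([0, 0, 1, 1, 1])
    y_score = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
    sample_weight = np.asarray([1.0, 1.0, 1.0, 0.5, 0.0])

    # Keep only samples with nonzero weight; zero-weight samples must not
    # move the thresholds of the resulting curve.
    nonzero_weight_mask = sample_weight != 0
    y_true = y_true[nonzero_weight_mask]
    y_score = y_score[nonzero_weight_mask]
    sample_weight = sample_weight[nonzero_weight_mask]

    print(y_true)         # [0 0 1 1]
    print(sample_weight)  # [1.  1.  1.  0.5]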

@@ -759,7 +766,9 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None,
         pos_label should be explicitly given.
 
     probas_pred : ndarray of shape (n_samples,)
-        Estimated probabilities or output of a decision function.
+        Target scores, can either be probability estimates of the positive
+        class, or non-thresholded measure of decisions (as returned by
+        `decision_function` on some classifiers).
 
     pos_label : int or str, default=None
         The label of the positive class.
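
The reworded `probas_pred` description reflects what the function accepts: probability estimates of the positive class or raw decision scores. A sketch (the estimator and data here are illustrative choices, not from the diff):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import precision_recall_curve

    X, y = make_classification(random_state=0)
    clf = LogisticRegression(random_state=0).fit(X, y)

    # Probabilities of the positive class ...
    p1, r1, _ = precision_recall_curve(y, clf.predict_proba(X)[:, 1])
    # ... or non-thresholded decision scores. The sigmoid is monotone, so
    # the precision/recall pairs coincide; only the thresholds differ.
    p2, r2, _ = precision_recall_curve(y, clf.decision_function(X))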

sklearn/metrics/tests/test_ranking.py

Lines changed: 64 additions & 57 deletions
@@ -41,6 +41,13 @@
 ###############################################################################
 # Utilities for testing
 
+CURVE_FUNCS = [
+    det_curve,
+    precision_recall_curve,
+    roc_curve,
+]
+
+
 def make_prediction(dataset=None, binary=False):
     """Make some classification predictions on a toy dataset using a SVC
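
`CURVE_FUNCS` gives the suite one list to parametrize over whenever a property should hold for all three curve functions. A hypothetical test (not part of this diff) showing the pattern:

    import pytest
    from sklearn.metrics import det_curve, precision_recall_curve, roc_curve

    CURVE_FUNCS = [det_curve, precision_recall_curve, roc_curve]

    @pytest.mark.parametrize("curve_func", CURVE_FUNCS)
    def test_returns_three_arrays(curve_func):
        # Each curve function returns a 3-tuple of 1d arrays.
        out = curve_func([0, 1, 1, 0], [0.1, 0.4, 0.35, 0.8])
        assert len(out) == 3
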
@@ -73,16 +80,16 @@ def make_prediction(dataset=None, binary=False):
 
     # run classifier, get class probabilities and label predictions
     clf = svm.SVC(kernel='linear', probability=True, random_state=0)
-    probas_pred = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
+    y_score = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
 
     if binary:
         # only interested in probabilities of the positive case
         # XXX: do we really want a special API for the binary case?
-        probas_pred = probas_pred[:, 1]
+        y_score = y_score[:, 1]
 
     y_pred = clf.predict(X[half:])
     y_true = y[half:]
-    return y_true, y_pred, probas_pred
+    return y_true, y_pred, y_score
 
 
 ###############################################################################
@@ -183,14 +190,14 @@ def _partial_roc(y_true, y_predict, max_fpr):
 @pytest.mark.parametrize('drop', [True, False])
 def test_roc_curve(drop):
     # Test Area under Receiver Operating Characteristic (ROC) curve
-    y_true, _, probas_pred = make_prediction(binary=True)
-    expected_auc = _auc(y_true, probas_pred)
+    y_true, _, y_score = make_prediction(binary=True)
+    expected_auc = _auc(y_true, y_score)
 
-    fpr, tpr, thresholds = roc_curve(y_true, probas_pred,
+    fpr, tpr, thresholds = roc_curve(y_true, y_score,
                                      drop_intermediate=drop)
     roc_auc = auc(fpr, tpr)
     assert_array_almost_equal(roc_auc, expected_auc, decimal=2)
-    assert_almost_equal(roc_auc, roc_auc_score(y_true, probas_pred))
+    assert_almost_equal(roc_auc, roc_auc_score(y_true, y_score))
     assert fpr.shape == tpr.shape
     assert fpr.shape == thresholds.shape

@@ -211,13 +218,13 @@ def test_roc_curve_end_points():
 def test_roc_returns_consistency():
     # Test whether the returned threshold matches up with tpr
     # make small toy dataset
-    y_true, _, probas_pred = make_prediction(binary=True)
-    fpr, tpr, thresholds = roc_curve(y_true, probas_pred)
+    y_true, _, y_score = make_prediction(binary=True)
+    fpr, tpr, thresholds = roc_curve(y_true, y_score)
 
     # use the given thresholds to determine the tpr
     tpr_correct = []
     for t in thresholds:
-        tp = np.sum((probas_pred >= t) & y_true)
+        tp = np.sum((y_score >= t) & y_true)
         p = np.sum(y_true)
         tpr_correct.append(1.0 * tp / p)

@@ -229,17 +236,17 @@ def test_roc_returns_consistency():
 
 def test_roc_curve_multi():
     # roc_curve not applicable for multi-class problems
-    y_true, _, probas_pred = make_prediction(binary=False)
+    y_true, _, y_score = make_prediction(binary=False)
 
     with pytest.raises(ValueError):
-        roc_curve(y_true, probas_pred)
+        roc_curve(y_true, y_score)
 
 
 def test_roc_curve_confidence():
     # roc_curve for confidence scores
-    y_true, _, probas_pred = make_prediction(binary=True)
+    y_true, _, y_score = make_prediction(binary=True)
 
-    fpr, tpr, thresholds = roc_curve(y_true, probas_pred - 0.5)
+    fpr, tpr, thresholds = roc_curve(y_true, y_score - 0.5)
     roc_auc = auc(fpr, tpr)
     assert_array_almost_equal(roc_auc, 0.90, decimal=2)
     assert fpr.shape == tpr.shape
@@ -248,7 +255,7 @@ def test_roc_curve_confidence():
 
 def test_roc_curve_hard():
     # roc_curve for hard decisions
-    y_true, pred, probas_pred = make_prediction(binary=True)
+    y_true, pred, y_score = make_prediction(binary=True)
 
     # always predict one
     trivial_pred = np.ones(y_true.shape)
@@ -668,23 +675,17 @@ def test_auc_score_non_binary_class():
         roc_auc_score(y_true, y_pred)
 
 
-def test_binary_clf_curve_multiclass_error():
+@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
+def test_binary_clf_curve_multiclass_error(curve_func):
     rng = check_random_state(404)
     y_true = rng.randint(0, 3, size=10)
     y_pred = rng.rand(10)
     msg = "multiclass format is not supported"
-
     with pytest.raises(ValueError, match=msg):
-        precision_recall_curve(y_true, y_pred)
-
-    with pytest.raises(ValueError, match=msg):
-        roc_curve(y_true, y_pred)
+        curve_func(y_true, y_pred)
 
 
-@pytest.mark.parametrize("curve_func", [
-    precision_recall_curve,
-    roc_curve,
-])
+@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
 def test_binary_clf_curve_implicit_pos_label(curve_func):
     # Check that using string class labels raises an informative
     # error for any supported string dtype:
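
Parametrizing over `CURVE_FUNCS` checks that every curve function rejects multiclass targets with the same message. The same check by hand, as a sketch with `roc_curve` standing in for any of the three:

    import numpy as np
    from sklearn.metrics import roc_curve

    rng = np.random.RandomState(404)
    y_true = rng.randint(0, 3, size=10)  # three classes
    y_pred = rng.rand(10)

    try:
        roc_curve(y_true, y_pred)
    except ValueError as exc:
        print(exc)  # multiclass format is not supported
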
@@ -693,10 +694,10 @@ def test_binary_clf_curve_implicit_pos_label(curve_func):
            "value in {0, 1} or {-1, 1} or pass pos_label "
            "explicitly.")
     with pytest.raises(ValueError, match=msg):
-        roc_curve(np.array(["a", "b"], dtype='<U1'), [0., 1.])
+        curve_func(np.array(["a", "b"], dtype='<U1'), [0., 1.])
 
     with pytest.raises(ValueError, match=msg):
-        roc_curve(np.array(["a", "b"], dtype=object), [0., 1.])
+        curve_func(np.array(["a", "b"], dtype=object), [0., 1.])
 
     # The error message is slightly different for bytes-encoded
     # class labels, but otherwise the behavior is the same:
@@ -705,25 +706,39 @@ def test_binary_clf_curve_implicit_pos_label(curve_func):
            "value in {0, 1} or {-1, 1} or pass pos_label "
            "explicitly.")
     with pytest.raises(ValueError, match=msg):
-        roc_curve(np.array([b"a", b"b"], dtype='<S1'), [0., 1.])
+        curve_func(np.array([b"a", b"b"], dtype='<S1'), [0., 1.])
 
     # Check that it is possible to use floating point class labels
     # that are interpreted similarly to integer class labels:
     y_pred = [0., 1., 0.2, 0.42]
-    int_curve = roc_curve([0, 1, 1, 0], y_pred)
-    float_curve = roc_curve([0., 1., 1., 0.], y_pred)
+    int_curve = curve_func([0, 1, 1, 0], y_pred)
+    float_curve = curve_func([0., 1., 1., 0.], y_pred)
     for int_curve_part, float_curve_part in zip(int_curve, float_curve):
         np.testing.assert_allclose(int_curve_part, float_curve_part)
 
 
+@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
+def test_binary_clf_curve_zero_sample_weight(curve_func):
+    y_true = [0, 0, 1, 1, 1]
+    y_score = [0.1, 0.2, 0.3, 0.4, 0.5]
+    sample_weight = [1, 1, 1, 0.5, 0]
+
+    result_1 = curve_func(y_true, y_score, sample_weight=sample_weight)
+    result_2 = curve_func(y_true[:-1], y_score[:-1],
+                          sample_weight=sample_weight[:-1])
+
+    for arr_1, arr_2 in zip(result_1, result_2):
+        assert_allclose(arr_1, arr_2)
+
+
 def test_precision_recall_curve():
-    y_true, _, probas_pred = make_prediction(binary=True)
-    _test_precision_recall_curve(y_true, probas_pred)
+    y_true, _, y_score = make_prediction(binary=True)
+    _test_precision_recall_curve(y_true, y_score)
 
     # Use {-1, 1} for labels; make sure original labels aren't modified
     y_true[np.where(y_true == 0)] = -1
     y_true_copy = y_true.copy()
-    _test_precision_recall_curve(y_true, probas_pred)
+    _test_precision_recall_curve(y_true, y_score)
     assert_array_equal(y_true_copy, y_true)
 
     labels = [1, 0, 0, 1]
@@ -736,31 +751,24 @@ def test_precision_recall_curve():
     assert p.size == t.size + 1
 
 
-def _test_precision_recall_curve(y_true, probas_pred):
+def _test_precision_recall_curve(y_true, y_score):
     # Test Precision-Recall and aread under PR curve
-    p, r, thresholds = precision_recall_curve(y_true, probas_pred)
-    precision_recall_auc = _average_precision_slow(y_true, probas_pred)
+    p, r, thresholds = precision_recall_curve(y_true, y_score)
+    precision_recall_auc = _average_precision_slow(y_true, y_score)
     assert_array_almost_equal(precision_recall_auc, 0.859, 3)
     assert_array_almost_equal(precision_recall_auc,
-                              average_precision_score(y_true, probas_pred))
+                              average_precision_score(y_true, y_score))
     # `_average_precision` is not very precise in case of 0.5 ties: be tolerant
-    assert_almost_equal(_average_precision(y_true, probas_pred),
+    assert_almost_equal(_average_precision(y_true, y_score),
                         precision_recall_auc, decimal=2)
     assert p.size == r.size
     assert p.size == thresholds.size + 1
     # Smoke test in the case of proba having only one value
-    p, r, thresholds = precision_recall_curve(y_true,
-                                              np.zeros_like(probas_pred))
+    p, r, thresholds = precision_recall_curve(y_true, np.zeros_like(y_score))
     assert p.size == r.size
     assert p.size == thresholds.size + 1
 
 
-def test_precision_recall_curve_errors():
-    # Contains non-binary labels
-    with pytest.raises(ValueError):
-        precision_recall_curve([0, 1, 2], [[0.0], [1.0], [1.0]])
-
-
 def test_precision_recall_curve_toydata():
     with np.errstate(all="raise"):
         # Binary classification
@@ -913,20 +921,20 @@ def test_score_scale_invariance():
     # This test was expanded (added scaled_down) in response to github
     # issue #3864 (and others), where overly aggressive rounding was causing
     # problems for users with very small y_score values
-    y_true, _, probas_pred = make_prediction(binary=True)
+    y_true, _, y_score = make_prediction(binary=True)
 
-    roc_auc = roc_auc_score(y_true, probas_pred)
-    roc_auc_scaled_up = roc_auc_score(y_true, 100 * probas_pred)
-    roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * probas_pred)
-    roc_auc_shifted = roc_auc_score(y_true, probas_pred - 10)
+    roc_auc = roc_auc_score(y_true, y_score)
+    roc_auc_scaled_up = roc_auc_score(y_true, 100 * y_score)
+    roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * y_score)
+    roc_auc_shifted = roc_auc_score(y_true, y_score - 10)
     assert roc_auc == roc_auc_scaled_up
     assert roc_auc == roc_auc_scaled_down
     assert roc_auc == roc_auc_shifted
 
-    pr_auc = average_precision_score(y_true, probas_pred)
-    pr_auc_scaled_up = average_precision_score(y_true, 100 * probas_pred)
-    pr_auc_scaled_down = average_precision_score(y_true, 1e-6 * probas_pred)
-    pr_auc_shifted = average_precision_score(y_true, probas_pred - 10)
+    pr_auc = average_precision_score(y_true, y_score)
+    pr_auc_scaled_up = average_precision_score(y_true, 100 * y_score)
+    pr_auc_scaled_down = average_precision_score(y_true, 1e-6 * y_score)
+    pr_auc_shifted = average_precision_score(y_true, y_score - 10)
     assert pr_auc == pr_auc_scaled_up
     assert pr_auc == pr_auc_scaled_down
    assert pr_auc == pr_auc_shifted
@@ -954,8 +962,7 @@ def test_score_scale_invariance():
     ([1, 0, 1], [0.5, 0.75, 1], [1, 1, 0], [0, 0.5, 0.5]),
     ([1, 0, 1], [0.25, 0.5, 0.75], [1, 1, 0], [0, 0.5, 0.5]),
 ])
-def test_det_curve_toydata(y_true, y_score,
-                           expected_fpr, expected_fnr):
+def test_det_curve_toydata(y_true, y_score, expected_fpr, expected_fnr):
     # Check on a batch of small examples.
     fpr, fnr, _ = det_curve(y_true, y_score)

sklearn/utils/validation.py

Lines changed: 2 additions & 1 deletion
@@ -1346,7 +1346,7 @@ def _check_sample_weight(sample_weight, X, dtype=None, copy=False):
     X : {ndarray, list, sparse matrix}
         Input data.
 
-    dtype: dtype, default=None
+    dtype : dtype, default=None
         dtype of the validated `sample_weight`.
         If None, and the input `sample_weight` is an array, the dtype of the
         input is preserved; otherwise an array with the default numpy dtype
@@ -1383,6 +1383,7 @@ def _check_sample_weight(sample_weight, X, dtype=None, copy=False):
     if sample_weight.shape != (n_samples,):
         raise ValueError("sample_weight.shape == {}, expected {}!"
                          .format(sample_weight.shape, (n_samples,)))
+
     return sample_weight
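
`_check_sample_weight` is a private helper, so treat this as a sketch of the contract the patch relies on rather than a stable API: `None` expands to unit weights, and an array is validated against `n_samples` and cast to float.

    import numpy as np
    from sklearn.utils.validation import _check_sample_weight

    X = np.zeros((4, 2))

    print(_check_sample_weight(None, X))            # [1. 1. 1. 1.]
    print(_check_sample_weight([1, 1, 0.5, 0], X))  # [1.  1.  0.5 0. ]

    try:
        _check_sample_weight([1, 1, 1], X)          # wrong length
    except ValueError as exc:
        print(exc)  # sample_weight.shape == (3,), expected (4,)!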

0 commit comments