 ###############################################################################
 # Utilities for testing
 
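+# Curve functions sharing the (y_true, y_score) interface, reused by the
+# parametrized tests below.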
+CURVE_FUNCS = [
+    det_curve,
+    precision_recall_curve,
+    roc_curve,
+]
+
+
 def make_prediction(dataset=None, binary=False):
     """Make some classification predictions on a toy dataset using a SVC
@@ -73,16 +80,16 @@ def make_prediction(dataset=None, binary=False):
 
     # run classifier, get class probabilities and label predictions
     clf = svm.SVC(kernel='linear', probability=True, random_state=0)
-    probas_pred = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
+    y_score = clf.fit(X[:half], y[:half]).predict_proba(X[half:])
 
     if binary:
         # only interested in probabilities of the positive case
         # XXX: do we really want a special API for the binary case?
-        probas_pred = probas_pred[:, 1]
+        y_score = y_score[:, 1]
 
     y_pred = clf.predict(X[half:])
     y_true = y[half:]
-    return y_true, y_pred, probas_pred
+    return y_true, y_pred, y_score
 
 
 ###############################################################################
@@ -183,14 +190,14 @@ def _partial_roc(y_true, y_predict, max_fpr):
 @pytest.mark.parametrize('drop', [True, False])
 def test_roc_curve(drop):
     # Test Area under Receiver Operating Characteristic (ROC) curve
-    y_true, _, probas_pred = make_prediction(binary=True)
-    expected_auc = _auc(y_true, probas_pred)
+    y_true, _, y_score = make_prediction(binary=True)
+    expected_auc = _auc(y_true, y_score)
 
-    fpr, tpr, thresholds = roc_curve(y_true, probas_pred,
+    fpr, tpr, thresholds = roc_curve(y_true, y_score,
                                      drop_intermediate=drop)
     roc_auc = auc(fpr, tpr)
     assert_array_almost_equal(roc_auc, expected_auc, decimal=2)
-    assert_almost_equal(roc_auc, roc_auc_score(y_true, probas_pred))
+    assert_almost_equal(roc_auc, roc_auc_score(y_true, y_score))
     assert fpr.shape == tpr.shape
     assert fpr.shape == thresholds.shape
@@ -211,13 +218,13 @@ def test_roc_curve_end_points():
 def test_roc_returns_consistency():
     # Test whether the returned threshold matches up with tpr
     # make small toy dataset
-    y_true, _, probas_pred = make_prediction(binary=True)
-    fpr, tpr, thresholds = roc_curve(y_true, probas_pred)
+    y_true, _, y_score = make_prediction(binary=True)
+    fpr, tpr, thresholds = roc_curve(y_true, y_score)
 
     # use the given thresholds to determine the tpr
     tpr_correct = []
     for t in thresholds:
-        tp = np.sum((probas_pred >= t) & y_true)
+        tp = np.sum((y_score >= t) & y_true)
         p = np.sum(y_true)
         tpr_correct.append(1.0 * tp / p)
@@ -229,17 +236,17 @@ def test_roc_returns_consistency():
 
 def test_roc_curve_multi():
     # roc_curve not applicable for multi-class problems
-    y_true, _, probas_pred = make_prediction(binary=False)
+    y_true, _, y_score = make_prediction(binary=False)
 
     with pytest.raises(ValueError):
-        roc_curve(y_true, probas_pred)
+        roc_curve(y_true, y_score)
 
 
 def test_roc_curve_confidence():
     # roc_curve for confidence scores
-    y_true, _, probas_pred = make_prediction(binary=True)
+    y_true, _, y_score = make_prediction(binary=True)
 
-    fpr, tpr, thresholds = roc_curve(y_true, probas_pred - 0.5)
+    fpr, tpr, thresholds = roc_curve(y_true, y_score - 0.5)
     roc_auc = auc(fpr, tpr)
     assert_array_almost_equal(roc_auc, 0.90, decimal=2)
     assert fpr.shape == tpr.shape
@@ -248,7 +255,7 @@ def test_roc_curve_confidence():
 
 def test_roc_curve_hard():
     # roc_curve for hard decisions
-    y_true, pred, probas_pred = make_prediction(binary=True)
+    y_true, pred, y_score = make_prediction(binary=True)
 
     # always predict one
     trivial_pred = np.ones(y_true.shape)
@@ -668,23 +675,17 @@ def test_auc_score_non_binary_class():
             roc_auc_score(y_true, y_pred)
 
 
-def test_binary_clf_curve_multiclass_error():
+@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
+def test_binary_clf_curve_multiclass_error(curve_func):
     rng = check_random_state(404)
     y_true = rng.randint(0, 3, size=10)
     y_pred = rng.rand(10)
     msg = "multiclass format is not supported"
-
     with pytest.raises(ValueError, match=msg):
-        precision_recall_curve(y_true, y_pred)
-
-    with pytest.raises(ValueError, match=msg):
-        roc_curve(y_true, y_pred)
+        curve_func(y_true, y_pred)
 
 
-@pytest.mark.parametrize("curve_func", [
-    precision_recall_curve,
-    roc_curve,
-])
+@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
 def test_binary_clf_curve_implicit_pos_label(curve_func):
     # Check that using string class labels raises an informative
     # error for any supported string dtype:
@@ -693,10 +694,10 @@ def test_binary_clf_curve_implicit_pos_label(curve_func):
            "value in {0, 1} or {-1, 1} or pass pos_label "
            "explicitly.")
     with pytest.raises(ValueError, match=msg):
-        roc_curve(np.array(["a", "b"], dtype='<U1'), [0., 1.])
+        curve_func(np.array(["a", "b"], dtype='<U1'), [0., 1.])
 
     with pytest.raises(ValueError, match=msg):
-        roc_curve(np.array(["a", "b"], dtype=object), [0., 1.])
+        curve_func(np.array(["a", "b"], dtype=object), [0., 1.])
 
     # The error message is slightly different for bytes-encoded
     # class labels, but otherwise the behavior is the same:
@@ -705,25 +706,39 @@ def test_binary_clf_curve_implicit_pos_label(curve_func):
            "value in {0, 1} or {-1, 1} or pass pos_label "
            "explicitly.")
     with pytest.raises(ValueError, match=msg):
-        roc_curve(np.array([b"a", b"b"], dtype='<S1'), [0., 1.])
+        curve_func(np.array([b"a", b"b"], dtype='<S1'), [0., 1.])
 
     # Check that it is possible to use floating point class labels
     # that are interpreted similarly to integer class labels:
     y_pred = [0., 1., 0.2, 0.42]
-    int_curve = roc_curve([0, 1, 1, 0], y_pred)
-    float_curve = roc_curve([0., 1., 1., 0.], y_pred)
+    int_curve = curve_func([0, 1, 1, 0], y_pred)
+    float_curve = curve_func([0., 1., 1., 0.], y_pred)
 
     for int_curve_part, float_curve_part in zip(int_curve, float_curve):
         np.testing.assert_allclose(int_curve_part, float_curve_part)
 
+@pytest.mark.parametrize("curve_func", CURVE_FUNCS)
+def test_binary_clf_curve_zero_sample_weight(curve_func):
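+    # A sample with zero weight should be equivalent to leaving that
+    # sample out of the inputs entirely.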
+    y_true = [0, 0, 1, 1, 1]
+    y_score = [0.1, 0.2, 0.3, 0.4, 0.5]
+    sample_weight = [1, 1, 1, 0.5, 0]
+
+    result_1 = curve_func(y_true, y_score, sample_weight=sample_weight)
+    result_2 = curve_func(y_true[:-1], y_score[:-1],
+                          sample_weight=sample_weight[:-1])
+
+    for arr_1, arr_2 in zip(result_1, result_2):
+        assert_allclose(arr_1, arr_2)
+
+
 def test_precision_recall_curve():
-    y_true, _, probas_pred = make_prediction(binary=True)
-    _test_precision_recall_curve(y_true, probas_pred)
+    y_true, _, y_score = make_prediction(binary=True)
+    _test_precision_recall_curve(y_true, y_score)
 
     # Use {-1, 1} for labels; make sure original labels aren't modified
     y_true[np.where(y_true == 0)] = -1
     y_true_copy = y_true.copy()
-    _test_precision_recall_curve(y_true, probas_pred)
+    _test_precision_recall_curve(y_true, y_score)
     assert_array_equal(y_true_copy, y_true)
 
     labels = [1, 0, 0, 1]
@@ -736,31 +751,24 @@ def test_precision_recall_curve():
     assert p.size == t.size + 1
 
 
-def _test_precision_recall_curve(y_true, probas_pred):
+def _test_precision_recall_curve(y_true, y_score):
     # Test Precision-Recall and area under PR curve
-    p, r, thresholds = precision_recall_curve(y_true, probas_pred)
-    precision_recall_auc = _average_precision_slow(y_true, probas_pred)
+    p, r, thresholds = precision_recall_curve(y_true, y_score)
+    precision_recall_auc = _average_precision_slow(y_true, y_score)
     assert_array_almost_equal(precision_recall_auc, 0.859, 3)
     assert_array_almost_equal(precision_recall_auc,
-                              average_precision_score(y_true, probas_pred))
+                              average_precision_score(y_true, y_score))
     # `_average_precision` is not very precise in case of 0.5 ties: be tolerant
-    assert_almost_equal(_average_precision(y_true, probas_pred),
+    assert_almost_equal(_average_precision(y_true, y_score),
                         precision_recall_auc, decimal=2)
     assert p.size == r.size
     assert p.size == thresholds.size + 1
     # Smoke test in the case of proba having only one value
-    p, r, thresholds = precision_recall_curve(y_true,
-                                              np.zeros_like(probas_pred))
+    p, r, thresholds = precision_recall_curve(y_true, np.zeros_like(y_score))
     assert p.size == r.size
     assert p.size == thresholds.size + 1
 
 
-def test_precision_recall_curve_errors():
-    # Contains non-binary labels
-    with pytest.raises(ValueError):
-        precision_recall_curve([0, 1, 2], [[0.0], [1.0], [1.0]])
-
-
 def test_precision_recall_curve_toydata():
     with np.errstate(all="raise"):
         # Binary classification
@@ -913,20 +921,20 @@ def test_score_scale_invariance():
     # This test was expanded (added scaled_down) in response to github
     # issue #3864 (and others), where overly aggressive rounding was causing
     # problems for users with very small y_score values
-    y_true, _, probas_pred = make_prediction(binary=True)
+    y_true, _, y_score = make_prediction(binary=True)
 
-    roc_auc = roc_auc_score(y_true, probas_pred)
-    roc_auc_scaled_up = roc_auc_score(y_true, 100 * probas_pred)
-    roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * probas_pred)
-    roc_auc_shifted = roc_auc_score(y_true, probas_pred - 10)
+    roc_auc = roc_auc_score(y_true, y_score)
+    roc_auc_scaled_up = roc_auc_score(y_true, 100 * y_score)
+    roc_auc_scaled_down = roc_auc_score(y_true, 1e-6 * y_score)
+    roc_auc_shifted = roc_auc_score(y_true, y_score - 10)
     assert roc_auc == roc_auc_scaled_up
     assert roc_auc == roc_auc_scaled_down
     assert roc_auc == roc_auc_shifted
 
-    pr_auc = average_precision_score(y_true, probas_pred)
-    pr_auc_scaled_up = average_precision_score(y_true, 100 * probas_pred)
-    pr_auc_scaled_down = average_precision_score(y_true, 1e-6 * probas_pred)
-    pr_auc_shifted = average_precision_score(y_true, probas_pred - 10)
+    pr_auc = average_precision_score(y_true, y_score)
+    pr_auc_scaled_up = average_precision_score(y_true, 100 * y_score)
+    pr_auc_scaled_down = average_precision_score(y_true, 1e-6 * y_score)
+    pr_auc_shifted = average_precision_score(y_true, y_score - 10)
     assert pr_auc == pr_auc_scaled_up
     assert pr_auc == pr_auc_scaled_down
     assert pr_auc == pr_auc_shifted
@@ -954,8 +962,7 @@ def test_score_scale_invariance():
     ([1, 0, 1], [0.5, 0.75, 1], [1, 1, 0], [0, 0.5, 0.5]),
     ([1, 0, 1], [0.25, 0.5, 0.75], [1, 1, 0], [0, 0.5, 0.5]),
 ])
-def test_det_curve_toydata(y_true, y_score,
-                           expected_fpr, expected_fnr):
+def test_det_curve_toydata(y_true, y_score, expected_fpr, expected_fnr):
     # Check on a batch of small examples.
     fpr, fnr, _ = det_curve(y_true, y_score)