@@ -915,10 +915,13 @@ def test_precision_recall_curve_toydata():
 
     y_true = [0, 0]
     y_score = [0.25, 0.75]
-    with pytest.raises(Exception):
-        precision_recall_curve(y_true, y_score)
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score)
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        p, r, _ = precision_recall_curve(y_true, y_score)
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        auc_prc = average_precision_score(y_true, y_score)
+    assert_allclose(p, [0, 1])
+    assert_allclose(r, [1, 0])
+    assert_allclose(auc_prc, 0)
 
     y_true = [1, 1]
     y_score = [0.25, 0.75]
@@ -930,29 +933,33 @@ def test_precision_recall_curve_toydata():
     # Multi-label classification task
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [0, 1]])
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="macro")
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="weighted")
-    assert_almost_equal(
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.5
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1.0
+        )
+    assert_allclose(
         average_precision_score(y_true, y_score, average="samples"), 1.0
     )
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="micro"), 1.0
-    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 1.0)
 
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="macro")
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="weighted")
-    assert_almost_equal(
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.5
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1.0
+        )
+    assert_allclose(
         average_precision_score(y_true, y_score, average="samples"), 0.75
     )
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="micro"), 0.5
-    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 0.5)
 
     y_true = np.array([[1, 0], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])
@@ -969,6 +976,35 @@ def test_precision_recall_curve_toydata():
         average_precision_score(y_true, y_score, average="micro"), 0.5
     )
 
+    y_true = np.array([[0, 0], [0, 0]])
+    y_score = np.array([[0, 1], [0, 1]])
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.0
+        )
+    assert_allclose(
+        average_precision_score(y_true, y_score, average="weighted"), 0.0
+    )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="samples"), 0.0
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="micro"), 0.0
+        )
+
+    y_true = np.array([[1, 1], [1, 1]])
+    y_score = np.array([[0, 1], [0, 1]])
+    assert_allclose(average_precision_score(y_true, y_score, average="macro"), 1.0)
+    assert_allclose(
+        average_precision_score(y_true, y_score, average="weighted"), 1.0
+    )
+    assert_allclose(
+        average_precision_score(y_true, y_score, average="samples"), 1.0
+    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 1.0)
+
     y_true = np.array([[1, 0], [0, 1]])
     y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
     assert_almost_equal(
@@ -988,9 +1024,10 @@ def test_precision_recall_curve_toydata():
     # if one class is never present weighted should not be NaN
    y_true = np.array([[0, 0], [0, 1]])
     y_score = np.array([[0, 0], [0, 1]])
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="weighted"), 1
-    )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1
+        )
 
 
 def test_average_precision_constant_values():
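A minimal sketch of the binary no-positive case these tests now pin down, assuming the warning-based handling the diff introduces: when y_true contains no positive sample, precision_recall_curve and average_precision_score emit a UserWarning instead of raising, and return the degenerate values asserted above.

import warnings

from sklearn.metrics import average_precision_score, precision_recall_curve

# Binary labels with no positive class present.
y_true = [0, 0]
y_score = [0.25, 0.75]

# Capture the warnings rather than letting them propagate, so the degenerate
# return values can be inspected alongside the warning itself.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    ap = average_precision_score(y_true, y_score)

print(precision, recall, ap)  # expected per the tests: [0. 1.] [1. 0.] 0.0
print(caught[0].category)     # UserWarning, matching "No positive class found in y_true"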