FIX f1_score with zero_division=1 uses directly confusion matrix statistic · punndcoder28/scikit-learn@3b06962 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3b06962

Browse files
glemaitre, OmarManzoor, and betatim
authored
FIX f1_score with zero_division=1 uses directly confusion matrix statistic (scikit-learn#27577)
Co-authored-by: Omar Salman <omar.salman@arbisoft.com> Co-authored-by: Tim Head <betatim@gmail.com>
1 parent cf56e95 commit 3b06962

File tree

3 files changed

+61
-44
lines changed

3 files changed

+61
-44
lines changed

doc/whats_new/v1.4.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,12 @@ Changelog
544544
:func:`sklearn.metrics.zero_one_loss` now support Array API compatible inputs.
545545
:pr:`27137` by :user:`Edoardo Abati <EdAbati>`.
546546

547+
- |Fix| :func:`f1_score` now provides correct values when handling various
548+
cases in which division by zero occurs by using a formulation that does not
549+
depend on the precision and recall values.
550+
:pr:`27577` by :user:`Omar Salman <OmarManzoor>` and
551+
:user:`Guillaume Lemaitre <glemaitre>`.
552+
547553
- |API| Deprecated `needs_threshold` and `needs_proba` from :func:`metrics.make_scorer`.
548554
These parameters will be removed in version 1.6. Instead, use `response_method` that
549555
accepts `"predict"`, `"predict_proba"` or `"decision_function"` or a list of such

sklearn/metrics/_classification.py

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,7 @@ def fbeta_score(
14341434
>>> y_pred_empty = [0, 0, 0, 0, 0, 0]
14351435
>>> fbeta_score(y_true, y_pred_empty,
14361436
... average="macro", zero_division=np.nan, beta=0.5)
1437-
0.38...
1437+
0.12...
14381438
"""
14391439

14401440
_, _, f, _ = precision_recall_fscore_support(
@@ -1482,20 +1482,8 @@ def _prf_divide(
14821482
return result
14831483

14841484
# build appropriate warning
1485-
# E.g. "Precision and F-score are ill-defined and being set to 0.0 in
1486-
# labels with no predicted samples. Use ``zero_division`` parameter to
1487-
# control this behavior."
1488-
1489-
if metric in warn_for and "f-score" in warn_for:
1490-
msg_start = "{0} and F-score are".format(metric.title())
1491-
elif metric in warn_for:
1492-
msg_start = "{0} is".format(metric.title())
1493-
elif "f-score" in warn_for:
1494-
msg_start = "F-score is"
1495-
else:
1496-
return result
1497-
1498-
_warn_prf(average, modifier, msg_start, len(result))
1485+
if metric in warn_for:
1486+
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
14991487

15001488
return result
15011489

@@ -1751,7 +1739,7 @@ def precision_recall_fscore_support(
17511739
array([0., 0., 1.]), array([0. , 0. , 0.8]),
17521740
array([2, 2, 2]))
17531741
"""
1754-
zero_division_value = _check_zero_division(zero_division)
1742+
_check_zero_division(zero_division)
17551743
labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
17561744

17571745
# Calculate tp_sum, pred_sum, true_sum ###
@@ -1784,26 +1772,25 @@ def precision_recall_fscore_support(
17841772
tp_sum, true_sum, "recall", "true", average, warn_for, zero_division
17851773
)
17861774

1787-
# warn for f-score only if zero_division is warn, it is in warn_for
1788-
# and BOTH prec and rec are ill-defined
1789-
if zero_division == "warn" and ("f-score",) == warn_for:
1790-
if (pred_sum[true_sum == 0] == 0).any():
1791-
_warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
1792-
17931775
if np.isposinf(beta):
17941776
f_score = recall
17951777
elif beta == 0:
17961778
f_score = precision
17971779
else:
17981780
# The score is defined as:
17991781
# score = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
1800-
# We set to `zero_division_value` if the denominator is 0 **or** if **both**
1801-
# precision and recall are ill-defined.
1802-
denom = beta2 * precision + recall
1803-
mask = np.isclose(denom, 0) | np.isclose(pred_sum + true_sum, 0)
1804-
denom[mask] = 1 # avoid division by 0
1805-
f_score = (1 + beta2) * precision * recall / denom
1806-
f_score[mask] = zero_division_value
1782+
# Therefore, we can express the score in terms of confusion matrix entries as:
1783+
# score = (1 + beta**2) * tp / ((1 + beta**2) * tp + beta**2 * fn + fp)
1784+
denom = beta2 * true_sum + pred_sum
1785+
f_score = _prf_divide(
1786+
(1 + beta2) * tp_sum,
1787+
denom,
1788+
"f-score",
1789+
"true nor predicted",
1790+
average,
1791+
warn_for,
1792+
zero_division,
1793+
)
18071794

18081795
# Average the results
18091796
if average == "weighted":

sklearn/metrics/tests/test_classification.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,7 +1809,7 @@ def test_precision_recall_f1_score_with_an_empty_prediction(
18091809

18101810
assert_array_almost_equal(p, [zero_division_expected, 1.0, 1.0, 0.0], 2)
18111811
assert_array_almost_equal(r, [0.0, 0.5, 1.0, zero_division_expected], 2)
1812-
expected_f = 0 if not np.isnan(zero_division_expected) else np.nan
1812+
expected_f = 0
18131813
assert_array_almost_equal(f, [expected_f, 1 / 1.5, 1, expected_f], 2)
18141814
assert_array_almost_equal(s, [1, 2, 1, 0], 2)
18151815

@@ -1826,7 +1826,7 @@ def test_precision_recall_f1_score_with_an_empty_prediction(
18261826

18271827
assert_almost_equal(p, (2 + value_to_sum) / values_to_average)
18281828
assert_almost_equal(r, (1.5 + value_to_sum) / values_to_average)
1829-
expected_f = (2 / 3 + 1) / (4 if not np.isnan(zero_division_expected) else 2)
1829+
expected_f = (2 / 3 + 1) / 4
18301830
assert_almost_equal(f, expected_f)
18311831
assert s is None
18321832
assert_almost_equal(
@@ -1859,7 +1859,7 @@ def test_precision_recall_f1_score_with_an_empty_prediction(
18591859
)
18601860
assert_almost_equal(p, 3 / 4 if zero_division_expected == 0 else 1.0)
18611861
assert_almost_equal(r, 0.5)
1862-
values_to_average = 4 if not np.isnan(zero_division_expected) else 3
1862+
values_to_average = 4
18631863
assert_almost_equal(f, (2 * 2 / 3 + 1) / values_to_average)
18641864
assert s is None
18651865
assert_almost_equal(
@@ -1877,12 +1877,12 @@ def test_precision_recall_f1_score_with_an_empty_prediction(
18771877
assert_almost_equal(r, 1 / 3)
18781878
assert_almost_equal(f, 1 / 3)
18791879
assert s is None
1880-
expected_result = {1: 0.666, np.nan: 1.0}
1880+
expected_result = 0.333
18811881
assert_almost_equal(
18821882
fbeta_score(
18831883
y_true, y_pred, beta=2, average="samples", zero_division=zero_division
18841884
),
1885-
expected_result.get(zero_division, 0.333),
1885+
expected_result,
18861886
2,
18871887
)
18881888

@@ -2012,7 +2012,7 @@ def test_prf_warnings():
20122012
f, w = precision_recall_fscore_support, UndefinedMetricWarning
20132013
for average in [None, "weighted", "macro"]:
20142014
msg = (
2015-
"Precision and F-score are ill-defined and "
2015+
"Precision is ill-defined and "
20162016
"being set to 0.0 in labels with no predicted samples."
20172017
" Use `zero_division` parameter to control"
20182018
" this behavior."
@@ -2021,7 +2021,7 @@ def test_prf_warnings():
20212021
f([0, 1, 2], [1, 1, 2], average=average)
20222022

20232023
msg = (
2024-
"Recall and F-score are ill-defined and "
2024+
"Recall is ill-defined and "
20252025
"being set to 0.0 in labels with no true samples."
20262026
" Use `zero_division` parameter to control"
20272027
" this behavior."
@@ -2031,7 +2031,7 @@ def test_prf_warnings():
20312031

20322032
# average of per-sample scores
20332033
msg = (
2034-
"Precision and F-score are ill-defined and "
2034+
"Precision is ill-defined and "
20352035
"being set to 0.0 in samples with no predicted labels."
20362036
" Use `zero_division` parameter to control"
20372037
" this behavior."
@@ -2040,7 +2040,7 @@ def test_prf_warnings():
20402040
f(np.array([[1, 0], [1, 0]]), np.array([[1, 0], [0, 0]]), average="samples")
20412041

20422042
msg = (
2043-
"Recall and F-score are ill-defined and "
2043+
"Recall is ill-defined and "
20442044
"being set to 0.0 in samples with no true labels."
20452045
" Use `zero_division` parameter to control"
20462046
" this behavior."
@@ -2050,7 +2050,7 @@ def test_prf_warnings():
20502050

20512051
# single score: micro-average
20522052
msg = (
2053-
"Precision and F-score are ill-defined and "
2053+
"Precision is ill-defined and "
20542054
"being set to 0.0 due to no predicted samples."
20552055
" Use `zero_division` parameter to control"
20562056
" this behavior."
@@ -2059,7 +2059,7 @@ def test_prf_warnings():
20592059
f(np.array([[1, 1], [1, 1]]), np.array([[0, 0], [0, 0]]), average="micro")
20602060

20612061
msg = (
2062-
"Recall and F-score are ill-defined and "
2062+
"Recall is ill-defined and "
20632063
"being set to 0.0 due to no true samples."
20642064
" Use `zero_division` parameter to control"
20652065
" this behavior."
@@ -2069,7 +2069,7 @@ def test_prf_warnings():
20692069

20702070
# single positive label
20712071
msg = (
2072-
"Precision and F-score are ill-defined and "
2072+
"Precision is ill-defined and "
20732073
"being set to 0.0 due to no predicted samples."
20742074
" Use `zero_division` parameter to control"
20752075
" this behavior."
@@ -2078,7 +2078,7 @@ def test_prf_warnings():
20782078
f([1, 1], [-1, -1], average="binary")
20792079

20802080
msg = (
2081-
"Recall and F-score are ill-defined and "
2081+
"Recall is ill-defined and "
20822082
"being set to 0.0 due to no true samples."
20832083
" Use `zero_division` parameter to control"
20842084
" this behavior."
@@ -2090,14 +2090,20 @@ def test_prf_warnings():
20902090
warnings.simplefilter("always")
20912091
precision_recall_fscore_support([0, 0], [0, 0], average="binary")
20922092
msg = (
2093-
"Recall and F-score are ill-defined and "
2093+
"F-score is ill-defined and being set to 0.0 due to no true nor "
2094+
"predicted samples. Use `zero_division` parameter to control this"
2095+
" behavior."
2096+
)
2097+
assert str(record.pop().message) == msg
2098+
msg = (
2099+
"Recall is ill-defined and "
20942100
"being set to 0.0 due to no true samples."
20952101
" Use `zero_division` parameter to control"
20962102
" this behavior."
20972103
)
20982104
assert str(record.pop().message) == msg
20992105
msg = (
2100-
"Precision and F-score are ill-defined and "
2106+
"Precision is ill-defined and "
21012107
"being set to 0.0 due to no predicted samples."
21022108
" Use `zero_division` parameter to control"
21032109
" this behavior."
@@ -2818,6 +2824,24 @@ def test_classification_metric_pos_label_types(metric, classes):
28182824
assert not np.any(np.isnan(result))
28192825

28202826

2827+
@pytest.mark.parametrize(
2828+
"y_true, y_pred, expected_score",
2829+
[
2830+
(np.array([0, 1]), np.array([1, 0]), 0.0),
2831+
(np.array([0, 1]), np.array([0, 1]), 1.0),
2832+
(np.array([0, 1]), np.array([0, 0]), 0.0),
2833+
(np.array([0, 0]), np.array([0, 0]), 1.0),
2834+
],
2835+
)
2836+
def test_f1_for_small_binary_inputs_with_zero_division(y_true, y_pred, expected_score):
2837+
"""Check the behaviour of `zero_division` for f1-score.
2838+
2839+
Non-regression test for:
2840+
https://github.com/scikit-learn/scikit-learn/issues/26965
2841+
"""
2842+
assert f1_score(y_true, y_pred, zero_division=1.0) == pytest.approx(expected_score)
2843+
2844+
28212845
@pytest.mark.parametrize(
28222846
"scoring",
28232847
[

0 commit comments

Comments (0)