10000 FEA add zero_division to matthews_corrcoef (#28509) · scikit-learn/scikit-learn@ba2dd5d · GitHub
[go: up one dir, main page]

Skip to content

Commit ba2dd5d

Browse files
Redjestmarctorsocglemaitre
authored
FEA add zero_division to matthews_corrcoef (#28509)
Co-authored-by: Marc Torrellas Socastro <marc.torsoc@gmail.com> Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 49c5948 commit ba2dd5d

File tree

3 files changed

+53
-12
lines changed

3 files changed

+53
-12
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- Adds `zero_division` to :func:`metrics.matthews_corrcoef`.
2+
When there is a zero division, the metric is undefined and this value is returned.
3+
By :user:`Marc Torrellas Socastro <marctorsoc>` and :user:`Noam Keidar <redjest>`

sklearn/metrics/_classification.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,10 +1015,15 @@ def jaccard_score(
10151015
"y_true": ["array-like"],
10161016
"y_pred": ["array-like"],
10171017
"sample_weight": ["array-like", None],
1018+
"zero_division": [
1019+
Options(Real, {0.0, 1.0}),
1020+
"nan",
1021+
StrOptions({"warn"}),
1022+
],
10181023
},
10191024
prefer_skip_nested_validation=True,
10201025
)
1021-
def matthews_corrcoef(y_true, y_pred, *, sample_weight=None):
1026+
def matthews_corrcoef(y_true, y_pred, *, sample_weight=None, zero_division="warn"):
10221027
"""Compute the Matthews correlation coefficient (MCC).
10231028
10241029
The Matthews correlation coefficient is used in machine learning as a
@@ -1049,6 +1054,13 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None):
10491054
10501055
.. versionadded:: 0.18
10511056
1057+
zero_division : {"warn", 0.0, 1.0, np.nan}, default="warn"
1058+
Sets the value to return when there is a zero division, i.e. when all
1059+
predictions and labels are negative. If set to "warn", this acts like 0,
1060+
but a warning is also raised.
1061+
1062+
.. versionadded:: 1.6
1063+
10521064
Returns
10531065
-------
10541066
mcc : float
@@ -1102,7 +1114,13 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None):
11021114
cov_ytyt = n_samples**2 - np.dot(t_sum, t_sum)
11031115

11041116
if cov_ypyp * cov_ytyt == 0:
1105-
return 0.0
1117+
if zero_division == "warn":
1118+
msg = (
1119+
"Matthews correlation coefficient is ill-defined and being set to 0.0. "
1120+
"Use `zero_division` to control this behaviour."
1121+
)
1122+
warnings.warn(msg, UndefinedMetricWarning, stacklevel=2)
1123+
return _check_zero_division(zero_division)
11061124
else:
11071125
return cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
11081126

sklearn/metrics/tests/test_classification.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -795,9 +795,22 @@ def test_cohen_kappa():
795795
)
796796

797797

798-
def test_matthews_corrcoef_nan():
799-
assert matthews_corrcoef([0], [1]) == 0.0
800-
assert matthews_corrcoef([0, 0], [0, 1]) == 0.0
798+
@pytest.mark.parametrize("zero_division", ["warn", 0, 1, np.nan])
799+
@pytest.mark.parametrize("y_true, y_pred", [([0], [1]), ([0, 0], [0, 1])])
800+
def test_matthews_corrcoef_zero_division(zero_division, y_true, y_pred):
801+
"""Check the behaviour of `zero_division` in `matthews_corrcoef`."""
802+
expected_result = 0.0 if zero_division == "warn" else zero_division
803+
804+
if zero_division == "warn":
805+
with pytest.warns(UndefinedMetricWarning):
806+
result = matthews_corrcoef(y_true, y_pred, zero_division=zero_division)
807+
else:
808+
result = matthews_corrcoef(y_true, y_pred, zero_division=zero_division)
809+
810+
if np.isnan(expected_result):
811+
assert np.isnan(result)
812+
else:
813+
assert result == expected_result
801814

802815

803816
@pytest.mark.parametrize("zero_division", [0, 1, np.nan])
@@ -924,15 +937,19 @@ def test_matthews_corrcoef():
924937

925938
# For the zero vector case, the corrcoef cannot be calculated and should
926939
# output 0
927-
assert_almost_equal(matthews_corrcoef([0, 0, 0, 0], [0, 0, 0, 0]), 0.0)
940+
assert_almost_equal(
941+
matthews_corrcoef([0, 0, 0, 0], [0, 0, 0, 0], zero_division=0), 0.0
942+
)
928943

929944
# And also for any other vector with 0 variance
930-
assert_almost_equal(matthews_corrcoef(y_true, ["a"] * len(y_true)), 0.0)
945+
assert_almost_equal(
946+
matthews_corrcoef(y_true, ["a"] * len(y_true), zero_division=0), 0.0
947+
)
931948

932949
# These two vectors have 0 correlation and hence mcc should be 0
933950
y_1 = [1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]
934951
y_2 = [1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1]
935-
assert_almost_equal(matthews_corrcoef(y_1, y_2), 0.0)
952+
assert_almost_equal(matthews_corrcoef(y_1, y_2, zero_division=0), 0.0)
936953

937954
# Check that sample weight is able to selectively exclude
938955
mask = [1] * 10 + [0] * 10
@@ -965,17 +982,17 @@ def test_matthews_corrcoef_multiclass():
965982
# Zero variance will result in an mcc of zero
966983
y_true = [0, 1, 2]
967984
y_pred = [3, 3, 3]
968-
assert_almost_equal(matthews_corrcoef(y_true, y_pred), 0.0)
985+
assert_almost_equal(matthews_corrcoef(y_true, y_pred, zero_division=0), 0.0)
969986

970987
# Also for ground truth with zero variance
971988
y_true = [3, 3, 3]
972989
y_pred = [0, 1, 2]
973-
assert_almost_equal(matthews_corrcoef(y_true, y_pred), 0.0)
990+
assert_almost_equal(matthews_corrcoef(y_true, y_pred, zero_division=0), 0.0)
974991

975992
# These two vectors have 0 correlation and hence mcc should be 0
976993
y_1 = [0, 1, 2, 0, 1, 2, 0, 1, 2]
977994
y_2 = [1, 1, 1, 2, 2, 2, 0, 0, 0]
978-
assert_almost_equal(matthews_corrcoef(y_1, y_2), 0.0)
995+
assert_almost_equal(matthews_corrcoef(y_1, y_2, zero_division=0), 0.0)
979996

980997
# We can test that binary assumptions hold using the multiclass computation
981998
# by masking the weight of samples not in the first two classes
@@ -994,7 +1011,10 @@ def test_matthews_corrcoef_multiclass():
9941011
y_pred = [0, 0, 1, 2]
9951012
sample_weight = [1, 1, 0, 0]
9961013
assert_almost_equal(
997-
matthews_corrcoef(y_true, y_pred, sample_weight=sample_weight), 0.0
1014+
matthews_corrcoef(
1015+
y_true, y_pred, sample_weight=sample_weight, zero_division=0.0
1016+
),
1017+
0.0,
9981018
)
9991019

10001020

0 commit comments

Comments
 (0)
0