ENH Add `zero_division` parameter for `accuracy_score` (#29213) · commit-0/scikit-learn@99f0f69 · GitHub
[go: up one dir, main page]

Skip to content

Commit 99f0f69

Browse files
Jaimin020, glemaitre, adrinjalali
authored
ENH Add zero_division parameter for accuracy_score (scikit-learn#29213)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai> Co-authored-by: adrinjalali <adrin.jalali@gmail.com>
1 parent 8388a1d commit 99f0f69

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- :func:`sklearn.metrics.accuracy_score` now includes a `zero_division`
2+
parameter to raise a warning when `y_true` and `y_pred` are empty.
3+
By :user:`Jaimin Chauhan <jaimin020>`.

sklearn/metrics/_classification.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,16 @@ def _check_targets(y_true, y_pred):
152152
"y_pred": ["array-like", "sparse matrix"],
153153
"normalize": ["boolean"],
154154
"sample_weight": ["array-like", None],
155+
"zero_division": [
156+
Options(Real, {0.0, 1.0, np.nan}),
157+
StrOptions({"warn"}),
158+
],
155159
},
156160
prefer_skip_nested_validation=True,
157161
)
158-
def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
162+
def accuracy_score(
163+
y_true, y_pred, *, normalize=True, sample_weight=None, zero_division="warn"
164+
):
159165
"""Accuracy classification score.
160166
161167
In multilabel classification, this function computes subset accuracy:
@@ -179,6 +185,13 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
179185
sample_weight : array-like of shape (n_samples,), default=None
180186
Sample weights.
181187
188+
zero_division : {"warn", 0.0, 1.0, np.nan}, default="warn"
189+
Sets the value to return when there is a zero division,
190+
e.g. when `y_true` and `y_pred` are empty.
191+
If set to "warn", this acts like 0.0, but a warning is also raised.
192+
193+
.. versionadded:: 1.6
194+
182195
Returns
183196
-------
184197
score : float or int
@@ -220,6 +233,16 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
220233
y_true, y_pred = attach_unique(y_true, y_pred)
221234
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
222235
check_consistent_length(y_true, y_pred, sample_weight)
236+
237+
if _num_samples(y_true) == 0:
238+
if zero_division == "warn":
239+
msg = (
240+
"accuracy() is ill-defined and set to 0.0. Use the `zero_division` "
241+
"param to control this behavior."
242+
)
243+
warnings.warn(msg, UndefinedMetricWarning)
244+
return _check_zero_division(zero_division)
245+
223246
if y_type.startswith("multilabel"):
224247
if _is_numpy_namespace(xp):
225248
differing_labels = count_nonzero(y_true - y_pred, axis=1)

sklearn/metrics/tests/test_classification.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,13 +809,19 @@ def test_matthews_corrcoef_nan():
809809
partial(fbeta_score, beta=1),
810810
precision_score,
811811
recall_score,
812+
accuracy_score,
812813
partial(cohen_kappa_score, labels=[0, 1]),
813814
],
814815
)
815816
def test_zero_division_nan_no_warning(metric, y_true, y_pred, zero_division):
816817
"""Check the behaviour of `zero_division` when setting to 0, 1 or np.nan.
817818
No warnings should be raised.
818819
"""
820+
if metric is accuracy_score and len(y_true):
821+
pytest.skip(
822+
reason="zero_division is only used with empty y_true/y_pred for accuracy"
823+
)
824+
819825
with warnings.catch_warnings():
820826
warnings.simplefilter("error")
821827
result = metric(y_true, y_pred, zero_division=zero_division)
@@ -834,13 +840,19 @@ def test_zero_division_nan_no_warning(metric, y_true, y_pred, zero_division):
834840
partial(fbeta_score, beta=1),
835841
precision_score,
836842
recall_score,
843+
accuracy_score,
837844
cohen_kappa_score,
838845
],
839846
)
840847
def test_zero_division_nan_warning(metric, y_true, y_pred):
841848
"""Check the behaviour of `zero_division` when setting to "warn".
842849
A `UndefinedMetricWarning` should be raised.
843850
"""
851+
if metric is accuracy_score and len(y_true):
852+
pytest.skip(
853+
reason="zero_division is only used with empty y_true/y_pred for accuracy"
854+
)
855+
844856
with pytest.warns(UndefinedMetricWarning):
845857
result = metric(y_true, y_pred, zero_division="warn")
846858
assert result == 0.0

0 commit comments

Comments
 (0)
0