scikit-learn/scikit-learn · Commit 4b5cf19
MAINT Parameters validation for precision_recall_fscore_support (#25681)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 0d5f434 · commit 4b5cf19

File tree: 3 files changed, +21 -20 lines

sklearn/metrics/_classification.py

Lines changed: 20 additions & 3 deletions
@@ -1468,6 +1468,25 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
     return labels
 
 
+@validate_params(
+    {
+        "y_true": ["array-like", "sparse matrix"],
+        "y_pred": ["array-like", "sparse matrix"],
+        "beta": [Interval(Real, 0.0, None, closed="both")],
+        "labels": ["array-like", None],
+        "pos_label": [Real, str, "boolean", None],
+        "average": [
+            StrOptions({"micro", "macro", "samples", "weighted", "binary"}),
+            None,
+        ],
+        "warn_for": [list, tuple, set],
+        "sample_weight": ["array-like", None],
+        "zero_division": [
+            Options(Real, {0, 1}),
+            StrOptions({"warn"}),
+        ],
+    }
+)
 def precision_recall_fscore_support(
     y_true,
     y_pred,
@@ -1556,7 +1575,7 @@ def precision_recall_fscore_support(
         meaningful for multilabel classification where this differs from
         :func:`accuracy_score`).
 
-    warn_for : tuple or set, for internal use
+    warn_for : list, tuple or set, for internal use
         This determines which warnings will be made in the case that this
         function is being used to return only one of its metrics.
 
@@ -1633,8 +1652,6 @@ def precision_recall_fscore_support(
            array([2, 2, 2]))
     """
     _check_zero_division(zero_division)
-    if beta < 0:
-        raise ValueError("beta should be >=0 in the F-beta score")
     labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
 
     # Calculate tp_sum, pred_sum, true_sum ###
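
The decorator replaces the ad-hoc `if beta < 0` check: every constraint is declared up front and enforced before the function body runs, producing uniform error messages across the public API. A minimal sketch of the resulting behavior, assuming a scikit-learn build that includes this commit (`InvalidParameterError` subclasses `ValueError`, so existing handlers keep working; the exact message wording below is an assumption):

from sklearn.metrics import precision_recall_fscore_support

y_true = [0, 1, 1, 0]
y_pred = [0, 1, 0, 0]

# beta is constrained to the interval [0.0, inf), so a negative value
# is rejected by the decorator before any metric computation happens.
try:
    precision_recall_fscore_support(y_true, y_pred, beta=-0.1)
except ValueError as exc:
    # Prints InvalidParameterError and the framework's templated message;
    # the wording is not quoted from the source, only its general shape.
    print(type(exc).__name__, exc)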

sklearn/metrics/tests/test_classification.py

Lines changed: 0 additions & 17 deletions
@@ -386,23 +386,6 @@ def test_average_precision_score_tied_values():
     assert average_precision_score(y_true, y_score) != 1.0
 
 
-@ignore_warnings
-def test_precision_recall_fscore_support_errors():
-    y_true, y_pred, _ = make_prediction(binary=True)
-
-    # Bad beta
-    with pytest.raises(ValueError):
-        precision_recall_fscore_support(y_true, y_pred, beta=-0.1)
-
-    # Bad pos_label
-    with pytest.raises(ValueError):
-        precision_recall_fscore_support(y_true, y_pred, pos_label=2, average="binary")
-
-    # Bad average option
-    with pytest.raises(ValueError):
-        precision_recall_fscore_support([0, 1, 2], [1, 2, 0], average="mega")
-
-
 def test_precision_recall_f_unused_pos_label():
     # Check warning that pos_label unused when set to non-default value
     # but average != 'binary'; even if data is binary.
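
Deleting this test does not drop coverage: the bad `beta` and bad `average` cases are now rejected by the declared constraints and exercised generically by the common parameter-validation test (see the sketch after the last file below), while the bad `pos_label` case is still checked at call time inside `_check_set_wise_labels`. A hedged sketch of equivalent standalone spot checks, relying only on `InvalidParameterError` being a `ValueError` subclass:

import pytest

from sklearn.metrics import precision_recall_fscore_support

y_true, y_pred = [0, 1, 1, 0], [0, 1, 0, 0]

# Rejected by Interval(Real, 0.0, None, closed="both") on beta.
with pytest.raises(ValueError):
    precision_recall_fscore_support(y_true, y_pred, beta=-0.1)

# Rejected by StrOptions({"micro", "macro", "samples", "weighted", "binary"}).
with pytest.raises(ValueError):
    precision_recall_fscore_support(y_true, y_pred, average="mega")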

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ def _check_function_param_validation(
     "sklearn.metrics.multilabel_confusion_matrix",
     "sklearn.metrics.mutual_info_score",
     "sklearn.metrics.pairwise.additive_chi2_kernel",
+    "sklearn.metrics.precision_recall_fscore_support",
     "sklearn.metrics.r2_score",
     "sklearn.metrics.roc_curve",
     "sklearn.metrics.zero_one_loss",

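Registering the fully qualified name in this list is what opts the function into the shared test: `_check_function_param_validation` resolves each entry, feeds invalid values to every declared parameter, and asserts the framework raises. A simplified sketch of that pattern (an assumption-level illustration; the real check does considerably more, and the `_skl_parameter_constraints` attribute set by `@validate_params` is an internal detail that may change):

from importlib import import_module

import pytest

PARAM_VALIDATION_FUNCTION_LIST = [
    "sklearn.metrics.precision_recall_fscore_support",
    # ... one entry per validated public function
]

@pytest.mark.parametrize("func_path", PARAM_VALIDATION_FUNCTION_LIST)
def test_function_param_validation(func_path):
    module_path, func_name = func_path.rsplit(".", 1)
    func = getattr(import_module(module_path), func_name)
    # @validate_params stores the declared constraints on the wrapper;
    # the shared check introspects them to generate invalid inputs.
    assert hasattr(func, "_skl_parameter_constraints")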