betatim/scikit-learn · Commit cd25abe
MAINT ensure that pos_label support all possible types (scikit-learn#25317)
1 parent 1714eed commit cd25abe

4 files changed: +70 / -14 lines changed

sklearn/metrics/_classification.py

Lines changed: 7 additions & 7 deletions
@@ -732,7 +732,7 @@ def jaccard_score(
         labels are column indices. By default, all labels in ``y_true`` and
         ``y_pred`` are used in sorted order.
 
-    pos_label : str or int, default=1
+    pos_label : int, float, bool or str, default=1
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
@@ -1083,7 +1083,7 @@ def f1_score(
         .. versionchanged:: 0.17
            Parameter `labels` improved for multiclass problem.
 
-    pos_label : str or int, default=1
+    pos_label : int, float, bool or str, default=1
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
@@ -1231,7 +1231,7 @@ def fbeta_score(
         .. versionchanged:: 0.17
            Parameter `labels` improved for multiclass problem.
 
-    pos_label : str or int, default=1
+    pos_label : int, float, bool or str, default=1
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
@@ -1491,7 +1491,7 @@ def precision_recall_fscore_support(
         labels are column indices. By default, all labels in ``y_true`` and
         ``y_pred`` are used in sorted order.
 
-    pos_label : str or int, default=1
+    pos_label : int, float, bool or str, default=1
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
@@ -1893,7 +1893,7 @@ def precision_score(
         .. versionchanged:: 0.17
            Parameter `labels` improved for multiclass problem.
 
-    pos_label : str or int, default=1
+    pos_label : int, float, bool or str, default=1
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
@@ -2034,7 +2034,7 @@ def recall_score(
         .. versionchanged:: 0.17
            Parameter `labels` improved for multiclass problem.
 
-    pos_label : str or int, default=1
+    pos_label : int, float, bool or str, default=1
         The class to report if ``average='binary'`` and the data is binary.
         If the data are multiclass or multilabel, this will be ignored;
         setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
@@ -2878,7 +2878,7 @@ def brier_score_loss(y_true, y_prob, *, sample_weight=None, pos_label=None):
     sample_weight : array-like of shape (n_samples,), default=None
         Sample weights.
 
-    pos_label : int or str, default=None
+    pos_label : int, float, bool or str, default=None
         Label of the positive class. `pos_label` will be inferred in the
         following manner:
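
The docstring updates above only widen the documented type of `pos_label`; the accepted call patterns look roughly like the sketch below (illustrative data, not part of the diff, assuming a scikit-learn build that includes this commit):

    import numpy as np
    from sklearn.metrics import f1_score, brier_score_loss

    # str labels: name the positive class explicitly
    y_true = np.array(["spam", "ham", "spam", "ham"])
    y_pred = np.array(["spam", "spam", "spam", "ham"])
    f1_score(y_true, y_pred, pos_label="spam")

    # bool labels: pos_label=True selects the positive class
    f1_score([True, False, True], [True, True, True], pos_label=True)

    # float labels with a probabilistic metric (second argument is predicted probabilities)
    brier_score_loss([1.0, 0.0, 1.0, 0.0], [0.9, 0.2, 0.8, 0.3], pos_label=1.0)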

sklearn/metrics/_ranking.py

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@
 
 import warnings
 from functools import partial
-from numbers import Real, Integral
+from numbers import Real
 
 import numpy as np
 from scipy.sparse import csr_matrix, issparse
@@ -252,7 +252,7 @@ def _binary_uninterpolated_average_precision(
     {
         "y_true": ["array-like"],
         "y_score": ["array-like"],
-        "pos_label": [Integral, str, None],
+        "pos_label": [Real, str, "boolean", None],
         "sample_weight": ["array-like", None],
     }
 )
@@ -278,7 +278,7 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
         class, confidence values, or non-thresholded measure of decisions
         (as returned by "decision_function" on some classifiers).
 
-    pos_label : int or str, default=None
+    pos_label : int, float, bool or str, default=None
         The label of the positive class.
         When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1},
         ``pos_label`` is set to 1, otherwise an error will be raised.
@@ -848,7 +848,7 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
         class, or non-thresholded measure of decisions (as returned by
         `decision_function` on some classifiers).
 
-    pos_label : int or str, default=None
+    pos_label : int, float, bool or str, default=None
         The label of the positive class.
         When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
         ``pos_label`` is set to 1, otherwise an error will be raised.
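
For the ranking metrics, the parameter-validation constraint switches from `Integral` to `Real` plus `"boolean"`, so float and (numpy) boolean labels pass validation alongside ints and strings; a plain float is not an `Integral`, and `numpy.bool_` is not registered as one either. A minimal usage sketch under the same assumption as above (illustrative data only):

    import numpy as np
    from sklearn.metrics import det_curve, precision_recall_curve

    y_score = np.array([0.8, 0.3, 0.6, 0.9, 0.4, 0.2])

    # boolean class labels
    y_true_bool = np.array([True, False, True, True, False, False])
    fpr, fnr, thresholds = det_curve(y_true_bool, y_score, pos_label=True)

    # float class labels
    y_true_float = np.array([1.0, 0.0, 1.0, 1.0, 0.0, 0.0])
    precision, recall, thresholds = precision_recall_curve(y_true_float, y_score, pos_label=1.0)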

sklearn/metrics/tests/test_classification.py

Lines changed: 33 additions & 3 deletions
@@ -103,7 +103,6 @@ def make_prediction(dataset=None, binary=False):
 
 
 def test_classification_report_dictionary_output():
-
     # Test performance report with dictionary output
     iris = datasets.load_iris()
     y_true, y_pred, _ = make_prediction(dataset=iris, binary=False)
@@ -1874,7 +1873,6 @@ def test_prf_warnings():
     # average of per-label scores
     f, w = precision_recall_fscore_support, UndefinedMetricWarning
     for average in [None, "weighted", "macro"]:
-
         msg = (
             "Precision and F-score are ill-defined and "
             "being set to 0.0 in labels with no predicted samples."
@@ -1974,7 +1972,6 @@ def test_prf_no_warnings_if_zero_division_set(zero_division):
     # average of per-label scores
     f = precision_recall_fscore_support
     for average in [None, "weighted", "macro"]:
-
         assert_no_warnings(
             f, [0, 1, 2], [1, 1, 2], average=average, zero_division=zero_division
         )
@@ -2635,3 +2632,36 @@ def test_balanced_accuracy_score(y_true, y_pred):
     adjusted = balanced_accuracy_score(y_true, y_pred, adjusted=True)
     chance = balanced_accuracy_score(y_true, np.full_like(y_true, y_true[0]))
     assert adjusted == (balanced - chance) / (1 - chance)
+
+
+@pytest.mark.parametrize(
+    "metric",
+    [
+        jaccard_score,
+        f1_score,
+        partial(fbeta_score, beta=0.5),
+        precision_recall_fscore_support,
+        precision_score,
+        recall_score,
+        brier_score_loss,
+    ],
+)
+@pytest.mark.parametrize(
+    "classes", [(False, True), (0, 1), (0.0, 1.0), ("zero", "one")]
+)
+def test_classification_metric_pos_label_types(metric, classes):
+    """Check that the metric works with different types of `pos_label`.
+
+    We can expect `pos_label` to be a bool, an integer, a float, a string.
+    No error should be raised for those types.
+    """
+    rng = np.random.RandomState(42)
+    n_samples, pos_label = 10, classes[-1]
+    y_true = rng.choice(classes, size=n_samples, replace=True)
+    if metric is brier_score_loss:
+        # brier score loss requires probabilities
+        y_pred = rng.uniform(size=n_samples)
+    else:
+        y_pred = y_true.copy()
+    result = metric(y_true, y_pred, pos_label=pos_label)
+    assert not np.any(np.isnan(result))

sklearn/metrics/tests/test_ranking.py

Lines changed: 26 additions & 0 deletions
@@ -2115,3 +2115,29 @@ def test_label_ranking_avg_precision_score_should_allow_csr_matrix_for_y_true_in
     y_score = np.array([[0.5, 0.9, 0.6], [0, 0, 1]])
     result = label_ranking_average_precision_score(y_true, y_score)
     assert result == pytest.approx(2 / 3)
+
+
+@pytest.mark.parametrize(
+    "metric", [average_precision_score, det_curve, precision_recall_curve, roc_curve]
+)
+@pytest.mark.parametrize(
+    "classes", [(False, True), (0, 1), (0.0, 1.0), ("zero", "one")]
+)
+def test_ranking_metric_pos_label_types(metric, classes):
+    """Check that the metric works with different types of `pos_label`.
+
+    We can expect `pos_label` to be a bool, an integer, a float, a string.
+    No error should be raised for those types.
+    """
+    rng = np.random.RandomState(42)
+    n_samples, pos_label = 10, classes[-1]
+    y_true = rng.choice(classes, size=n_samples, replace=True)
+    y_proba = rng.rand(n_samples)
+    result = metric(y_true, y_proba, pos_label=pos_label)
+    if isinstance(result, float):
+        assert not np.isnan(result)
+    else:
+        metric_1, metric_2, thresholds = result
+        assert not np.isnan(metric_1).any()
+        assert not np.isnan(metric_2).any()
+        assert not np.isnan(thresholds).any()
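
The branch on `isinstance(result, float)` in this test reflects the two return shapes involved: `average_precision_score` returns a single float, while the curve functions return a tuple of arrays. A tiny sketch of that distinction (illustrative values, same assumptions as the earlier sketches):

    import numpy as np
    from sklearn.metrics import average_precision_score, roc_curve

    y_true = np.array(["zero", "one", "one", "zero"])
    y_score = np.array([0.1, 0.8, 0.6, 0.4])

    ap = average_precision_score(y_true, y_score, pos_label="one")      # single float
    fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label="one")  # three arrays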
