Merge pull request #4147 from amueller/precision_recall_unsorted_indices · scikit-learn/scikit-learn@92e1e39 · GitHub
[go: up one dir, main page]

Skip to content

Commit 92e1e39

Browse files
committed
Merge pull request #4147 from amueller/precision_recall_unsorted_indices
FIX Sort labels in precision_recall_fscore_support
2 parents ca55e74 + fd9b03a commit 92e1e39

File tree

4 files changed

+42
-1
lines changed

4 files changed

+42
-1
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,10 @@ Bug fixes
276276
:class:`sklearn.naive_bayes.MultinomialNB` and
277277
:class:`sklearn.naive_bayes.BernoulliNB`. By `Trevor Stephens`_.
278278

279+
- Fixed a crash in :func:`metrics.precision_recall_fscore_support`
280+
when using unsorted ``labels`` in the multi-label setting.
281+
By `Andreas Müller`_.
282+
279283
API changes summary
280284
-------------------
281285

sklearn/metrics/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
847847
if labels is None:
848848
labels = unique_labels(y_true, y_pred)
849849
else:
850-
labels = np.asarray(labels)
850+
labels = np.sort(labels)
851851

852852
### Calculate tp_sum, pred_sum, true_sum ###
853853

sklearn/metrics/tests/test_classification.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,20 @@ def test_precision_recall_f1_score_multiclass():
303303
assert_array_equal(s, [24, 20, 31])
304304

305305

306+
def test_precision_refcall_f1_score_multilabel_unordered_labels():
307+
# test that labels need not be sorted in the multilabel case
308+
y_true = np.array([[1, 1, 0, 0]])
309+
y_pred = np.array([[0, 0, 1, 1]])
310+
for average in ['samples', 'micro', 'macro', 'weighted', None]:
311+
p, r, f, s = precision_recall_fscore_support(
312+
y_true, y_pred, labels=[4, 1, 2, 3], warn_for=[], average=average)
313+
assert_array_equal(p, 0)
314+
assert_array_equal(r, 0)
315+
assert_array_equal(f, 0)
316+
if average is None:
317+
assert_array_equal(s, [0, 1, 1, 0])
318+
319+
306320
def test_precision_recall_f1_score_multiclass_pos_label_none():
307321
"""Test Precision Recall and F1 Score for multiclass classification task
308322

sklearn/metrics/tests/test_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,3 +1064,26 @@ def test_sample_weight_invariance(n_samples=50):
10641064
else:
10651065
yield (check_sample_weight_invariance, name, metric, y_true,
10661066
y_pred)
1067+
1068+
1069+
def test_no_averaging_labels():
1070+
# test labels argument when not using averaging
1071+
# in multi-class and multi-label cases
1072+
y_true_multilabel = np.array([[1, 1, 0, 0], [1, 1, 0, 0]])
1073+
y_pred_multilabel = np.array([[0, 0, 1, 1], [0, 1, 1, 0]])
1074+
y_true_multiclass = np.array([1, 2, 3])
1075+
y_pred_multiclass = np.array([1, 3, 4])
1076+
labels = np.array([4, 1, 2, 3])
1077+
_, inverse_labels = np.unique(labels, return_inverse=True)
1078+
1079+
for name in METRICS_WITH_AVERAGING:
1080+
for y_true, y_pred in [[y_true_multiclass, y_pred_multiclass],
1081+
[y_true_multilabel, y_pred_multilabel]]:
1082+
if name not in MULTILABELS_METRICS and y_pred.shape[1] > 0:
1083+
continue
1084+
1085+
metric = ALL_METRICS[name]
1086+
1087+
score_labels = metric(y_true, y_pred, labels=labels, average=None)
1088+
score = metric(y_true, y_pred, average=None)
1089+
assert_array_equal(score_labels, score[inverse_labels])

0 commit comments

Comments (0)