8000 ENH warn in classification_report when target_names doesn't equal lab… · sergeyf/scikit-learn@568c998 · GitHub
[go: up one dir, main page]

Skip to content

Commit 568c998

Browse files
Kenneth Myerssergeyf
authored andcommitted
ENH warn in classification_report when target_names doesn't equal labels size (scikit-learn#7802)
* Added warning for classification_report when target_names doesn't equal labels size and tests for such a case.
1 parent d3b73e0 commit 568c998

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

sklearn/metrics/classification.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,6 +1396,12 @@ class 2 1.00 0.67 0.80 3
13961396
else:
13971397
labels = np.asarray(labels)
13981398

1399+
if target_names is not None and len(labels) != len(target_names):
1400+
warnings.warn(
1401+
"labels size, {0}, does not match size of target_names, {1}"
1402+
.format(len(labels), len(target_names))
1403+
)
1404+
13991405
last_line_heading = 'avg / total'
14001406

14011407
if target_names is None:

sklearn/metrics/tests/test_classification.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,18 @@ def test_classification_report_multiclass_with_long_string_label():
722722
assert_equal(report, expected_report)
723723

724724

725+
def test_classification_report_labels_target_names_unequal_length():
726+
y_true = [0, 0, 2, 0, 0]
727+
y_pred = [0, 2, 2, 0, 0]
728+
target_names = ['class 0', 'class 1', 'class 2']
729+
730+
assert_warns_message(UserWarning,
731+
"labels size, 2, does not "
732+
"match size of target_names, 3",
733+
classification_report,
734+
y_true, y_pred, target_names=target_names)
735+
736+
725737
def test_multilabel_classification_report():
726738
n_classes = 4
727739
n_samples = 50

0 commit comments

Comments
 (0)
0