diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
index ee07fa634d080..7d9b938f534b4 100644
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -1392,6 +1392,12 @@ class 2 1.00 0.67 0.80 3
     else:
         labels = np.asarray(labels)
 
+    if target_names is not None and len(labels) != len(target_names):
+        warnings.warn(
+            "labels size, {0}, does not match size of target_names, {1}"
+            .format(len(labels), len(target_names))
+        )
+
     last_line_heading = 'avg / total'
 
     if target_names is None:
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index dc8a7c0686b59..e9616e933b70c 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -722,6 +722,18 @@ def test_classification_report_multiclass_with_long_string_label():
     assert_equal(report, expected_report)
 
 
+def test_classification_report_labels_target_names_unequal_length():
+    y_true = [0, 0, 2, 0, 0]
+    y_pred = [0, 2, 2, 0, 0]
+    target_names = ['class 0', 'class 1', 'class 2']
+
+    assert_warns_message(UserWarning,
+                         "labels size, 2, does not "
+                         "match size of target_names, 3",
+                         classification_report,
+                         y_true, y_pred, target_names=target_names)
+
+
 def test_multilabel_classification_report():
     n_classes = 4
     n_samples = 50
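
Note (illustration only, not part of the patch): below is a minimal sketch of how the new warning surfaces at call time, assuming a scikit-learn build with this change applied. The data mirrors the new test case: only two classes actually occur in y_true/y_pred, so the inferred labels array has length 2, while three target_names are supplied.

    # Sketch: trigger the UserWarning added by this patch.
    # Assumes scikit-learn with this change applied; data mirrors the test.
    import warnings
    from sklearn.metrics import classification_report

    y_true = [0, 0, 2, 0, 0]   # only classes 0 and 2 occur
    y_pred = [0, 2, 2, 0, 0]
    target_names = ['class 0', 'class 1', 'class 2']   # 3 names for 2 labels

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        classification_report(y_true, y_pred, target_names=target_names)

    print(caught[0].message)
    # labels size, 2, does not match size of target_names, 3

The report itself is still produced; the warning simply flags that the extra names are silently ignored when rows are built from the inferred labels.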