diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 7ebe82a39b884..74b20810de6db 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -54,6 +54,10 @@ Changelog value of the ``zero_division`` keyword argument. :pr:`15879` by :user:`Bibhash Chandra Mitra `. +- |Fix| Fixed a bug in :func:`metrics.plot_confusion_matrix` to correctly + pass the `values_format` parameter to the :class:`ConfusionMatrixDisplay` + plot() call. :pr:`15937` by :user:`Stephen Blystone `. + :mod:`sklearn.semi_supervised` .............................. diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index 11a456aa635b1..537d2b9f0d838 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -195,4 +195,5 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None, disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels) return disp.plot(include_values=include_values, - cmap=cmap, ax=ax, xticks_rotation=xticks_rotation) + cmap=cmap, ax=ax, xticks_rotation=xticks_rotation, + values_format=values_format) diff --git a/sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py b/sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py index 2d53e6bf24dc0..9f708b151b81b 100644 --- a/sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py +++ b/sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py @@ -245,3 +245,22 @@ def test_confusion_matrix_pipeline(pyplot, clf, data, n_classes): assert_allclose(disp.confusion_matrix, cm) assert disp.text_.shape == (n_classes, n_classes) + + +@pytest.mark.parametrize("values_format", ['e', 'n']) +def test_confusion_matrix_text_format(pyplot, data, y_pred, n_classes, + fitted_clf, values_format): + # Make sure plot text is formatted with 'values_format'. + X, y = data + cm = confusion_matrix(y, y_pred) + disp = plot_confusion_matrix(fitted_clf, X, y, + include_values=True, + values_format=values_format) + + assert disp.text_.shape == (n_classes, n_classes) + + expected_text = np.array([format(v, values_format) + for v in cm.ravel()]) + text_text = np.array([ + t.get_text() for t in disp.text_.ravel()]) + assert_array_equal(expected_text, text_text)