8000 FIX Add missing 'values_format' param to disp.plot() in plot_confusio… · scikit-learn/scikit-learn@eb3ad2d · GitHub
[go: up one dir, main page]

Skip to content

Commit eb3ad2d

Browse files
blynotesqinhanmin2014
authored andcommitted
FIX Add missing 'values_format' param to disp.plot() in plot_confusion_matrix (#15937)
1 parent dae52f9 commit eb3ad2d

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

doc/whats_new/v0.22.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ Changelog
5454
value of the ``zero_division`` keyword argument. :pr:`15879`
5555
by :user:`Bibhash Chandra Mitra <Bibyutatsu>`.
5656

57+
- |Fix| Fixed a bug in :func:`metrics.plot_confusion_matrix` to correctly
58+
pass the `values_format` parameter to the :class:`ConfusionMatrixDisplay`
59+
plot() call. :pr:`15937` by :user:`Stephen Blystone <blynotes>`.
60+
5761
:mod:`sklearn.semi_supervised`
5862
..............................
5963

sklearn/metrics/_plot/confusion_matrix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,4 +195,5 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None,
195195
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
196196
display_labels=display_labels)
197197
return disp.plot(include_values=include_values,
198-
cmap=cmap, ax=ax, xticks_rotation=xticks_rotation)
198+
cmap=cmap, ax=ax, xticks_rotation=xticks_rotation,
199+
values_format=values_format)

sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,22 @@ def test_confusion_matrix_pipeline(pyplot, clf, data, n_classes):
245245

246246
assert_allclose(disp.confusion_matrix, cm)
247247
assert disp.text_.shape == (n_classes, n_classes)
248+
249+
250+
@pytest.mark.parametrize("values_format", ['e', 'n'])
251+
def test_confusion_matrix_text_format(pyplot, data, y_pred, n_classes,
252+
fitted_clf, values_format):
253+
# Make sure plot text is formatted with 'values_format'.
254+
X, y = data
255+
cm = confusion_matrix(y, y_pred)
256+
disp = plot_confusion_matrix(fitted_clf, X, y,
257+
include_values=True,
258+
values_format=values_format)
259+
260+
assert disp.text_.shape == (n_classes, n_classes)
261+
262+
expected_text = np.array([format(v, values_format)
263+
for v in cm.ravel()])
264+
text_text = np.array([
265+
t.get_text() for t in disp.text_.ravel()])
266+
assert_array_equal(expected_text, text_text)

0 commit comments

Comments
 (0)
0