8000 FIX Uses the color max for colormap in ConfusionMatrixDisplay (#19784) · scikit-learn/scikit-learn@7f30867 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f30867

Browse files
authored
FIX Uses the color max for colormap in ConfusionMatrixDisplay (#19784)
1 parent b15e312 commit 7f30867

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

doc/whats_new/v1.0.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,9 @@ Changelog
403403
are integral.
404404
:pr:`9843` by :user:`Jon Crall <Erotemic>`.
405405

406+
- |Fix| :meth:`metrics.ConfusionMatrixDisplay.plot` uses the correct max
407+
for colormap. :pr:`19784` by `Thomas Fan`_.
408+
406409
- |Fix| Samples with zero `sample_weight` values do not affect the results
407410
from :func:`metrics.det_curve`, :func:`metrics.precision_recall_curve`
408411
and :func:`metrics.roc_curve`.

sklearn/metrics/_plot/confusion_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def plot(self, *, include_values=True, cmap='viridis',
119119
n_classes = cm.shape[0]
120120
self.im_ = ax.imshow(cm, interpolation='nearest', cmap=cmap)
121121
self.text_ = None
122-
cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(256)
122+
cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)
123123

124124
if include_values:
125125
self.text_ = np.empty_like(cm, dtype=object)

sklearn/metrics/_plot/tests/test_confusion_matrix_display.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,17 @@ def test_confusion_matrix_with_unknown_labels(pyplot, constructor_name):
380380
display_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
381381
expected_labels = [str(i) for i in range(n_classes + 1)]
382382
assert_array_equal(expected_labels, display_labels)
383+
384+
385+
def test_colormap_max(pyplot):
386+
"""Check that the max color is used for the color of the text."""
387+
388+
from matplotlib import cm
389+
gray = cm.get_cmap('gray', 1024)
390+
confusion_matrix = np.array([[1.0, 0.0], [0.0, 1.0]])
391+
392+
disp = ConfusionMatrixDisplay(confusion_matrix)
393+
disp.plot(cmap=gray)
394+
395+
color = disp.text_[1, 0].get_color()
396+
assert_allclose(color, [1.0, 1.0, 1.0, 1.0])

0 commit comments

Comments
 (0)
0