diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index d26c5dd0c347d..08f6d34c6a379 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -324,6 +324,9 @@ Changelog are integral. :pr:`9843` by :user:`Jon Crall `. +- |Fix| :meth:`metrics.ConfusionMatrixDisplay.plot` uses the correct max + for colormap. :pr:`19784` by `Thomas Fan`_. + - |Fix| Samples with zero `sample_weight` values do not affect the results from :func:`metrics.det_curve`, :func:`metrics.precision_recall_curve` and :func:`metrics.roc_curve`. diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index 9fcecec775e6e..891ceede25b9d 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -122,7 +122,7 @@ def plot(self, *, include_values=True, cmap='viridis', n_classes = cm.shape[0] self.im_ = ax.imshow(cm, interpolation='nearest', cmap=cmap) self.text_ = None - cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(256) + cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0) if include_values: self.text_ = np.empty_like(cm, dtype=object) diff --git a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py index ed0bc04117396..b1498afae89ae 100644 --- a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py +++ b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py @@ -380,3 +380,17 @@ def test_confusion_matrix_with_unknown_labels(pyplot, constructor_name): display_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()] expected_labels = [str(i) for i in range(n_classes + 1)] assert_array_equal(expected_labels, display_labels) + + +def test_colormap_max(pyplot): + """Check that the max color is used for the color of the text.""" + + from matplotlib import cm + gray = cm.get_cmap('gray', 1024) + confusion_matrix = np.array([[1.0, 0.0], [0.0, 1.0]]) + + disp = ConfusionMatrixDisplay(confusion_matrix) + disp.plot(cmap=gray) + + color = disp.text_[1, 0].get_color() + assert_allclose(color, [1.0, 1.0, 1.0, 1.0])