8000 [MRG] BUG Fixes constrast in plot_confusion_matrix by thomasjpfan · Pull Request #15936 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
8000

[MRG] BUG Fixes constrast in plot_confusion_matrix #15936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v0.22.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ Changelog
is invalid. Previously, it runs fine with no normalization.
:pr:`15888` by `Hanmin Qin`_.

- |Fix| :func:`metrics.plot_confusion_matrix` now colors the label color
correctly to maximize contrast with its background. :pr:`15936` by
`Thomas Fan`_ and :user:`DizietAsahi`.

:mod:`sklearn.utils`
....................

Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/_plot/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def plot(self, include_values=True, cmap='viridis',
values_format = '.2g'

# print text with appropriate color depending on background
thresh = (cm.max() - cm.min()) / 2.
thresh = (cm.max() + cm.min()) / 2.0
for i, j in product(range(n_classes), range(n_classes)):
color = cmap_max if cm[i, j] < thresh else cmap_min
self.text_[i, j] = ax.text(j, i,
Expand Down
18 changes: 16 additions & 2 deletions sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py
Original file line number Diff line number D A1B2 iff line change
Expand Up @@ -200,7 +200,7 @@ def test_confusion_matrix_contrast(pyplot):
assert_allclose(disp.text_[0, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
assert_allclose(disp.text_[1, 1].get_color(), [0.0, 0.0, 0.0, 1.0])

# oof-diagonal text is white
# off-diagonal text is white
assert_allclose(disp.text_[0, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
assert_allclose(disp.text_[1, 0].get_color(), [1.0, 1.0, 1.0, 1.0])

Expand All @@ -209,10 +209,24 @@ def test_confusion_matrix_contrast(pyplot):
assert_allclose(disp.text_[0, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
assert_allclose(disp.text_[1, 0].get_color(), [0.0, 0.0, 0.0, 1.0])

# oof-diagonal text is black
# off-diagonal text is black
assert_allclose(disp.text_[0, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
assert_allclose(disp.text_[1, 1].get_color(), [1.0, 1.0, 1.0, 1.0])

# Regression test for #15920
cm = np.array([[19, 34], [32, 58]])
disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

disp.plot(cmap=pyplot.cm.Blues)
min_color = pyplot.cm.Blues(0)
max_color = pyplot.cm.Blues(255)
assert_allclose(disp.text_[0, 0].get_color(), max_color)
assert_allclose(disp.text_[0, 1].get_color(), max_color)
assert_allclose(disp.text_[1, 0].get_color(), max_color)
assert_allclose(disp.text_[1, 1].get_color(), min_color)




@pytest.mark.parametrize(
"clf", [LogisticRegression(),
Expand Down
0