Enhancement to Confusion Matrix Output Representation for improving readability #19012 by shubhamdo · Pull Request #19190 · scikit-learn/scikit-learn · GitHub

Open · wants to merge 13 commits into main

10 changes: 10 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -214,6 +214,16 @@ Changelog
:class:`calibration.CalibratedClassifierCV` can now properly be used on
prefitted pipelines. :pr:`19641` by :user:`Alek Lefebvre <AlekLefebvre>`

:mod:`sklearn.metrics`
............................

- |Enhancement| :func:`metrics.confusion_matrix` can now return the confusion
  matrix as a dict keyed by ``(true_label, predicted_label)`` pairs when the
  new ``as_dict`` parameter is set to True. :pr:`19190`
  by :user:`Shubham Shinde <shubhamdo>`, :user:`Max Kinner <maxkinner>`,
  :user:`Varun John <varunjohn786>` and :user:`Vinayak Parab <vinayak-parab>`
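
For background on the readability gap this entry addresses: the default
ndarray output carries no label information, so the reader has to know the
row/column ordering. A quick illustration on a released version, using the
same data as the docstring examples below:

    >>> from sklearn.metrics import confusion_matrix
    >>> y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
    >>> y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
    >>> confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])
    array([[2, 0, 0],
           [0, 0, 1],
           [1, 0, 2]])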

Code and Documentation Contributors
-----------------------------------

56 changes: 46 additions & 10 deletions sklearn/metrics/_classification.py
@@ -212,7 +212,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):

@_deprecate_positional_args
def confusion_matrix(y_true, y_pred, *, labels=None, sample_weight=None,
-                     normalize=None):
+                     normalize=None, as_dict=False):
"""Compute confusion matrix to evaluate the accuracy of a classification.

By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`
@@ -249,14 +249,26 @@ def confusion_matrix(y_true, y_pred, *, labels=None, sample_weight=None,
conditions or all the population. If None, confusion matrix will not be
normalized.

    as_dict : bool, default=False
        If True, return the confusion matrix as a dict whose keys are
        ``(true_label, predicted_label)`` tuples of strings; such a dict is
        easily converted into an unstacked pandas Series. ``normalize`` is
        ignored when ``as_dict=True``. See Examples.

    Returns
    -------
    C : ndarray of shape (n_classes, n_classes), or dict of
        length n_classes * n_classes

        If ``as_dict=False``, the confusion matrix as an ndarray whose
        i-th row and j-th column entry indicates the number of samples
        with true label being i-th class and predicted label being
        j-th class.

        If ``as_dict=True``, the confusion matrix as a dict whose keys
        are ``(true_label, predicted_label)`` tuples and whose values
        are the number of samples with that true label and predicted
        label.

See Also
--------
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
@@ -282,12 +294,26 @@ def confusion_matrix(y_true, y_pred, *, labels=None, sample_weight=None,
[0, 0, 1],
[1, 0, 2]])

    Using the ``as_dict`` parameter:

    >>> y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
    >>> y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
    >>> cm = confusion_matrix(y_true, y_pred,
    ...                       labels=["ant", "bird", "cat"], as_dict=True)
    >>> cm
    {('ant', 'ant'): 2, ('bird', 'ant'): 0, ('cat', 'ant'): 1,
    ('ant', 'bird'): 0, ('bird', 'bird'): 0, ('cat', 'bird'): 0,
    ('ant', 'cat'): 0, ('bird', 'cat'): 1, ('cat', 'cat'): 2}

    The dict can be converted to a pandas Series and unstacked into a
    labeled table:

>>> import pandas as pd
>>> pd.Series(cm).unstack()
ant bird cat
ant 2 0 0
bird 0 0 1
cat 1 0 2


In the binary case, we can extract true positives, etc as follows:
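
For reference, the collapsed context at this point is the existing
binary-case example, which in scikit-learn's docstring reads:

    >>> tn, fp, fn, tp = confusion_matrix([0, 1, 0, 1], [1, 1, 1, 0]).ravel()
    >>> (tn, fp, fn, tp)
    (0, 2, 1, 1)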

@@ -346,6 +372,15 @@ def confusion_matrix(y_true, y_pred, *, labels=None, sample_weight=None,
shape=(n_labels, n_labels), dtype=dtype,
).toarray()

    if as_dict:
        # Map each (true_label, predicted_label) pair, with the labels cast
        # to str, to its count; the predicted label varies in the outer loop.
        label_list = [str(label) for label in labels.tolist()]
        cm_list = cm.tolist()
        return {(label_list[j], label_list[i]): cm_list[j][i]
                for i in range(len(label_list))
                for j in range(len(label_list))}

with np.errstate(all='ignore'):
if normalize == 'true':
cm = cm / cm.sum(axis=1, keepdims=True)
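
An aside on the ``as_dict`` branch above: it returns before the
``np.errstate`` block, so ``normalize`` is silently ignored when
``as_dict=True``. The transformation itself is easy to reproduce outside the
patch; a standalone sketch (the function name is hypothetical, not part of
this PR or of scikit-learn):

    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels

    def confusion_matrix_as_dict(y_true, y_pred, labels=None):
        """Replicate the patch: {(true_label, pred_label): count}."""
        if labels is None:
            labels = unique_labels(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        names = [str(label) for label in labels]
        # Same key order as the patch: predicted label in the outer loop.
        return {(names[j], names[i]): cm[j, i].item()
                for i in range(len(names))
                for j in range(len(names))}

    y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
    y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
    cm_dict = confusion_matrix_as_dict(y_true, y_pred,
                                       labels=["ant", "bird", "cat"])
    assert cm_dict[("cat", "ant")] == 1  # one true cat was predicted as ant
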
@@ -1985,7 +2020,7 @@ class 2 1.00 0.67 0.80 3
if labels_given:
warnings.warn(
"labels size, {0}, does not match size of target_names, {1}"
-                .format(len(labels), len(target_names))
+                    .format(len(labels), len(target_names))
)
else:
raise ValueError(
@@ -2047,8 +2082,9 @@ class 2 1.00 0.67 0.80 3
else:
if line_heading == 'accuracy':
row_fmt_accuracy = '{:>{width}s} ' + \
-                        ' {:>9.{digits}}' * 2 + ' {:>9.{digits}f}' + \
-                        ' {:>9}\n'
+                        ' {:>9.{digits}}' * 2 + \
+                        ' {:>9.{digits}f}' + \
+                        ' {:>9}\n'
report += row_fmt_accuracy.format(line_heading, '', '',
*avg[2:], width=width,
digits=digits)
19 changes: 17 additions & 2 deletions sklearn/metrics/tests/test_classification.py
@@ -1,4 +1,3 @@

from functools import partial
from itertools import product
from itertools import chain
@@ -429,7 +428,6 @@ def test(y_true, y_pred):
cm = multilabel_confusion_matrix(y_true, y_pred)
assert_array_equal(cm, [[[17, 8], [3, 22]],
[[22, 3], [8, 17]]])

test(y_true, y_pred)
test([str(y) for y in y_true],
[str(y) for y in y_pred])
@@ -590,6 +588,23 @@ def test_confusion_matrix_normalize_single_class():
assert not rec


def test_confusion_matrix_pprint():
    # Check the dict representation of the confusion matrix in the
    # multiclass case, for both int and str labels.
    y_true, y_pred, _ = make_prediction()

    def test(y_true, y_pred):
        cm = confusion_matrix(y_true, y_pred, as_dict=True)
        assert cm == {('0', '0'): 19, ('1', '0'): 4, ('2', '0'): 0,
                      ('0', '1'): 4, ('1', '1'): 3, ('2', '1'): 2,
                      ('0', '2'): 1, ('1', '2'): 24, ('2', '2'): 18}

    test(y_true, y_pred)
    test([str(y) for y in y_true],
         [str(y) for y in y_pred])
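
A consistency check the PR could add (a sketch, not in the diff): every dict
entry should equal the corresponding cell of the ndarray form.

    def test_confusion_matrix_as_dict_matches_ndarray():
        # Hypothetical extra test: dict entries match the ndarray cells,
        # with labels sorted to mirror the default row/column order.
        y_true, y_pred, _ = make_prediction()
        cm_arr = confusion_matrix(y_true, y_pred)
        cm_dict = confusion_matrix(y_true, y_pred, as_dict=True)
        labels = sorted(set(y_true) | set(y_pred))
        for i, true_label in enumerate(labels):
            for j, pred_label in enumerate(labels):
                assert cm_dict[(str(true_label), str(pred_label))] \
                    == cm_arr[i, j]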



def test_cohen_kappa():
# These label vectors reproduce the contingency matrix from Artstein and
# Poesio (2008), Table 1: np.array([[20, 20], [10, 50]]).