@zachguo
Last active May 31, 2022 17:39
Pretty print for sklearn confusion matrix
from sklearn.metrics import confusion_matrix


def print_cm(cm, labels, hide_zeroes=False, hide_diagonal=False, hide_threshold=None):
    """Pretty-print a confusion matrix."""
    columnwidth = max([len(x) for x in labels] + [5])  # 5 is the minimum cell width
    empty_cell = " " * columnwidth
    # Print header (Python 3 print(); end=" " keeps output on the same line)
    print("    " + empty_cell, end=" ")
    for label in labels:
        print("%{0}s".format(columnwidth) % label, end=" ")
    print()
    # Print rows
    for i, label1 in enumerate(labels):
        print("    %{0}s".format(columnwidth) % label1, end=" ")
        for j in range(len(labels)):
            cell = "%{0}.1f".format(columnwidth) % cm[i, j]
            if hide_zeroes:
                cell = cell if float(cm[i, j]) != 0 else empty_cell
            if hide_diagonal:
                cell = cell if i != j else empty_cell
            if hide_threshold:
                cell = cell if cm[i, j] > hide_threshold else empty_cell
            print(cell, end=" ")
        print()


# first generate with specified labels
labels = [ ... ]
# sklearn expects (y_true, y_pred); rows are true labels, columns are predictions.
# labels must be passed by keyword in recent scikit-learn versions.
cm = confusion_matrix(y, ypred, labels=labels)
# then print it in a pretty way
print_cm(cm, labels)
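
To see the fixed-width layout on concrete data, here is a minimal, dependency-free sketch. It is not the gist's code: `confusion_counts` and `format_cm` are hypothetical helpers written for illustration, the matrix is a plain nested list rather than a NumPy array, and the table is returned as a string instead of being printed cell by cell.

```python
from collections import Counter


def confusion_counts(y_true, y_pred, labels):
    """Count (true, predicted) pairs; rows are true labels, columns are predictions."""
    pairs = Counter(zip(y_true, y_pred))
    return [[pairs[(t, p)] for p in labels] for t in labels]


def format_cm(cm, labels, hide_zeroes=False):
    """Return the table as one string, using a single fixed column width."""
    width = max([len(str(x)) for x in labels] + [5])  # 5 is the minimum cell width
    empty_cell = " " * width
    # Header: an empty corner cell followed by right-aligned predicted labels
    lines = [" ".join([empty_cell] + [f"{str(x):>{width}}" for x in labels])]
    for label, row in zip(labels, cm):
        cells = [
            empty_cell if hide_zeroes and value == 0 else f"{value:{width}.1f}"
            for value in row
        ]
        lines.append(" ".join([f"{str(label):>{width}}"] + cells))
    return "\n".join(lines)


# Made-up example data, only for illustration
y_true = ["cat", "cat", "dog", "dog", "bird"]
y_pred = ["cat", "dog", "dog", "dog", "cat"]
labels = ["bird", "cat", "dog"]

cm = confusion_counts(y_true, y_pred, labels)
table = format_cm(cm, labels)
print(table)
```

Because every cell, including the hidden ones, is padded to the same width, all rows come out the same length and the columns stay aligned.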
@Coruscate5

t/p labeling is helpful, thanks

@botbark

botbark commented Dec 29, 2019

In many cases you may want to print the confusion matrix in a more readable format than the one scikit-learn provides by default.

https://botbark.com/2019/12/28/visualize-and-print-confusion-matrix/

@beeb

beeb commented Jan 23, 2020

I made a slightly different version that uses Python's f-strings (3.6+) to simplify the syntax. It takes the y_true and y_pred data directly and computes the list of labels if none is provided. It is type-annotated too.

from typing import List, Optional

import numpy as np
from sklearn.metrics import confusion_matrix

def print_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: Optional[List] = None,
    hide_zeroes: bool = False,
    hide_diagonal: bool = False,
    hide_threshold: Optional[float] = None,
):
    """Print a nicely formatted confusion matrix with labelled rows and columns.

    Predicted labels are in the top horizontal header, true labels on the vertical header.

    Args:
        y_true (np.ndarray): ground truth labels
        y_pred (np.ndarray): predicted labels
        labels (Optional[List], optional): list of all labels. If None, then all labels present in the data are
            displayed. Defaults to None.
        hide_zeroes (bool, optional): replace zero-values with an empty cell. Defaults to False.
        hide_diagonal (bool, optional): replace true positives (diagonal) with empty cells. Defaults to False.
        hide_threshold (Optional[float], optional): replace values less than or equal to this threshold with empty
            cells. Set to None to display all values. Defaults to None.
    """
    if labels is None:
        labels = np.unique(np.concatenate((y_true, y_pred)))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    # find which fixed column width will be used for the matrix
    columnwidth = max(
        [len(str(x)) for x in labels] + [5]
    )  # 5 is the minimum column width, otherwise the longest class name
    empty_cell = ' ' * columnwidth

    # top-left cell of the table that indicates that top headers are predicted classes, left headers are true classes
    padding_fst_cell = (columnwidth - 3) // 2  # double-slash is int division
    fst_empty_cell = padding_fst_cell * ' ' + 't/p' + ' ' * (columnwidth - padding_fst_cell - 3)

    # Print header
    print('    ' + fst_empty_cell, end=' ')
    for label in labels:
        # '>' right-aligns the label (str left-aligns by default); !s matches the str() used for columnwidth
        print(f'{label!s:>{columnwidth}}', end=' ')

    print()  # newline
    # Print rows
    for i, label in enumerate(labels):
        print(f'    {label!s:>{columnwidth}}', end=' ')  # right-aligned label padded with spaces to columnwidth
        for j in range(len(labels)):
            # cell value padded to columnwidth with spaces and displayed with 1 decimal
            cell = f'{cm[i, j]:{columnwidth}.1f}'
            if hide_zeroes:
                cell = cell if float(cm[i, j]) != 0 else empty_cell
            if hide_diagonal:
                cell = cell if i != j else empty_cell
            if hide_threshold:
                cell = cell if cm[i, j] > hide_threshold else empty_cell
            print(cell, end=' ')
        print()
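
The label and cell formatting in the function above relies on nested width fields inside f-string format specs. The following is a small sketch of how those specs behave, with illustrative variable names that are not part of the function. The point worth noting is that strings are left-aligned by default while numbers are right-aligned, so text labels need an explicit `>` to line up with numeric cells.

```python
width = 7  # stands in for columnwidth

text_default = f"{'cat':{width}}"    # str: left-aligned by default
number_default = f"{3:{width}}"      # int: right-aligned by default
text_right = f"{'cat':>{width}}"     # explicit '>' forces right alignment
cell = f"{2:{width}.1f}"             # one decimal place, padded to width

# Centered 't/p' corner cell, computed the same way as in the function above
padding = (width - 3) // 2
corner = " " * padding + "t/p" + " " * (width - padding - 3)

for name, value in [
    ("text_default", text_default),
    ("number_default", number_default),
    ("text_right", text_right),
    ("cell", cell),
    ("corner", corner),
]:
    print(f"{name:15}|{value}|")  # bars make the padding visible
```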

@hrieke

hrieke commented Sep 17, 2020

I know this is old code, but do you have a license you'd like to release this under?
Thanks!

@zachguo
Author

zachguo commented Sep 17, 2020

@hrieke Not really, use it however you want.

@hrieke

hrieke commented Sep 17, 2020

Thank you for the fast reply and the "okay to use" statement, but it would be nice to have it under a more standard license.
May I suggest one of the public domain dedications? CC0 1.0 Universal or Public Domain are two good choices.
If not, the MIT License, BSD 3-Clause, or Apache License are also very nice.
