Skip to content

Instantly share code, notes, and snippets.

@FRodrigues21
Last active January 4, 2023 19:16
Show Gist options
  • Save FRodrigues21/bec41ee4305c027bcdf9987313182e9b to your computer and use it in GitHub Desktop.
Save FRodrigues21/bec41ee4305c027bcdf9987313182e9b to your computer and use it in GitHub Desktop.
Parse a scikit-learn classification_report and convert it to a LaTeX table (Python 3 / Booktabs)
"""
Code to parse sklearn classification_report
Original: https://gist.github.com/julienr/6b9b9a03bd8224db7b4f
Modified to work with Python 3 and classification report averages
"""
import sys
import collections
def parse_classification_report(clfreport):
    """
    Parse a sklearn classification report into a dict keyed by class name
    and containing a tuple (precision, recall, fscore, support) for each class.

    The last row of the report (e.g. ``weighted avg``) is stored under the
    key ``'avg'``. Any other average rows (``micro avg``, ``macro avg``, ...)
    are kept under their own names, in report order.

    An ``accuracy`` row only carries two values (accuracy, support). It does
    not fit the 4-tuple layout, so it is skipped.

    :param clfreport: text produced by ``sklearn.metrics.classification_report``
    :return: ``collections.OrderedDict`` mapping class name to
        (precision, recall, fscore, support)
    :raises ValueError: if the text does not look like a classification report
    """
    # Drop blank lines; splitlines() also copes with '\r\n' line endings.
    lines = [line for line in clfreport.splitlines() if line.strip()]
    if len(lines) < 2:
        raise ValueError(
            "classification report must contain a header and at least one row"
        )

    # Starts with a header, then a score row for each class and finally an average.
    header = lines[0]
    cls_lines = lines[1:-1]
    avg_line = lines[-1]

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if header.split() != ['precision', 'recall', 'f1-score', 'support']:
        raise ValueError(f"unexpected classification report header: {header!r}")
    avg_fields = avg_line.split()
    if len(avg_fields) < 2 or avg_fields[1] != 'avg':
        raise ValueError(f"last report row is not an average row: {avg_line!r}")

    # We cannot simply use split() because class names can contain spaces.
    # Instead, take the width of the class-name column from the indentation
    # of the first column title in the header.
    cls_field_width = len(header) - len(header.lstrip())

    def split_line(line):
        """Return (class name, list of value fields) for one report row."""
        return line[:cls_field_width].strip(), line[cls_field_width:].split()

    def to_scores(line, fields):
        """Convert the four value fields of a row to (p, r, f, support)."""
        if len(fields) != 4:
            raise ValueError(f"expected 4 values in report row: {line!r}")
        precision, recall, fscore, support = fields
        return (float(precision), float(recall), float(fscore), int(support))

    data = collections.OrderedDict()
    for line in cls_lines:
        cls_name, fields = split_line(line)
        if cls_name == 'accuracy' and len(fields) == 2:
            # Only (accuracy, support): no per-class precision/recall to store.
            continue
        data[cls_name] = to_scores(line, fields)

    # The final average row is always exposed under the fixed key 'avg'.
    data['avg'] = to_scores(avg_line, split_line(avg_line)[1])
    return data
def _latex_escape(text):
    """Escape characters that have a special meaning in LaTeX text mode."""
    replacements = {
        '\\': r'\textbackslash{}',
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
    }
    return ''.join(replacements.get(ch, ch) for ch in text)


def report_to_latex_table(data, *, caption="Latex Table from Classification Report",
                          label="table:classification:report", digits=None):
    """
    Render parsed classification-report data as a LaTeX ``table`` environment.

    ``\\midrule`` comes from the booktabs package. One rule separates the
    header, and a second one is inserted before the first average row (a key
    equal to ``'avg'`` or ending in ``' avg'``, e.g. ``micro avg`` or
    ``macro avg``).

    :param data: mapping of class name to an iterable of scores, as returned
        by ``parse_classification_report``
    :param caption: LaTeX caption text (inserted verbatim, not escaped)
    :param label: LaTeX label (inserted verbatim, not escaped)
    :param digits: if given, float scores are formatted with this many
        decimals (e.g. ``0.50`` instead of ``0.5``); ``None`` keeps ``str()``
    :return: the LaTeX source, without a trailing newline
    """
    def fmt(value):
        if digits is not None and isinstance(value, float):
            return f"{value:.{digits}f}"
        return str(value)

    rows = [
        "\\begin{table}",
        f"\\caption{{{caption}}}",
        f"\\label{{{label}}}",
        "\\centering",
        "\\begin{tabular}{c | c c c r}",
        "Class & Precision & Recall & F-score & Support\\\\",
        "\\midrule",
    ]
    averages_started = False
    for cls, scores in data.items():
        cls = str(cls)
        is_average = cls == 'avg' or cls.endswith(' avg')
        if is_average and not averages_started:
            # Visually separate the per-class rows from the summary rows.
            rows.append("\\midrule")
            averages_started = True
        # Escape class names so values like "class_a" compile in LaTeX.
        cells = [_latex_escape(cls)] + [fmt(s) for s in scores]
        rows.append(" & ".join(cells) + "\\\\")
    rows.append("\\end{tabular}")
    rows.append("\\end{table}")
    return "\n".join(rows)
if __name__ == '__main__':
    # Usage: python cfreportlatex.py [REPORT_FILE]
    # With no argument (or '-') the report text is read from standard input.
    if len(sys.argv) > 2:
        sys.exit(f"usage: {sys.argv[0]} [REPORT_FILE]")
    path = sys.argv[1] if len(sys.argv) == 2 else '-'
    if path == '-':
        report_text = sys.stdin.read()
    else:
        with open(path, encoding='utf-8') as f:
            report_text = f.read()
    try:
        data = parse_classification_report(report_text)
    except ValueError as err:
        # Report a malformed input cleanly instead of with a traceback.
        sys.exit(f"error: {err}")
    print(report_to_latex_table(data))
precision recall f1-score support
0 0.98 0.88 0.93 14404
1 0.35 0.82 0.50 217
2 0.54 0.77 0.64 502
3 0.32 0.77 0.45 246
4 0.54 0.76 0.63 984
5 0.69 0.78 0.73 23
6 0.59 0.75 0.66 36
7 0.33 0.86 0.48 80
micro avg 0.87 0.87 0.87 16492
macro avg 0.54 0.80 0.63 16492
weighted avg 0.92 0.87 0.88 16492

Command:

python cfreportlatex.py example.txt

Result:

\begin{table}
\caption{Latex Table from Classification Report}
\label{table:classification:report}
\centering
\begin{tabular}{c | c c c r}
Class & Precision & Recall & F-score & Support\\
\midrule
0 & 0.98 & 0.88 & 0.93 & 14404\\
1 & 0.35 & 0.82 & 0.5 & 217\\
2 & 0.54 & 0.77 & 0.64 & 502\\
3 & 0.32 & 0.77 & 0.45 & 246\\
4 & 0.54 & 0.76 & 0.63 & 984\\
5 & 0.69 & 0.78 & 0.73 & 23\\
6 & 0.59 & 0.75 & 0.66 & 36\\
7 & 0.33 & 0.86 & 0.48 & 80\\
\midrule
micro avg & 0.87 & 0.87 & 0.87 & 16492\\
macro avg & 0.54 & 0.8 & 0.63 & 16492\\
avg & 0.92 & 0.87 & 0.88 & 16492\\
\end{tabular}
\end{table}
@isitadaptative
Copy link

I had problems with the code because the classification report includes an extra 'accuracy' row with no values in the precision and recall columns. It failed with "not enough values to unpack (expected 4, got 2)". To work around this, I added a line to parse_classification_report(), right after the empty lines are removed, that drops the accuracy line from the list:

lines=[s for s in lines if 'accuracy' not in s]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment