Forked from julienr/sklearn_classif_report_to_latex.py
Created
April 22, 2017 09:37
-
-
Save shashankg7/012cded1748d6774d29d9e41819bb70f to your computer and use it in GitHub Desktop.
Parse and convert scikit-learn classification_report to latex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Parse a scikit-learn classification_report and convert it to a LaTeX table.
""" | |
## | |
import sys | |
import collections | |
## | |
def parse_classification_report(clfreport):
    """
    Parse a sklearn classification report into a dict keyed by row label.

    Each value is a tuple (precision, recall, fscore, support). Per-class rows
    come first, in report order, followed by the summary rows:

    - Old-style reports (scikit-learn < 0.20) end with a single
      'avg / total' row, which is stored under the key 'avg'.
    - Newer reports end with rows such as 'accuracy', 'macro avg' and
      'weighted avg'; each is stored under its own label. The 'accuracy'
      row only fills the f1-score column and support, so its precision and
      recall are None.

    Raises ValueError if the text does not look like a classification report
    (empty text, unexpected header, no summary row, or a row with an
    unexpected number of columns).
    """
    # Drop blank lines. A list comprehension (rather than filter) keeps the
    # result indexable on Python 3, where filter returns an iterator.
    lines = [line for line in clfreport.split('\n') if line.strip()]
    if not lines:
        raise ValueError('empty classification report')
    header, rows = lines[0], lines[1:]
    if header.split() != ['precision', 'recall', 'f1-score', 'support']:
        raise ValueError('unexpected classification report header: %r' % header)
    # We cannot simply use split because class names can have spaces. So
    # instead figure the width of the class field by looking at the
    # indentation of the precision header.
    cls_field_width = len(header) - len(header.lstrip())

    def label_of(line):
        """Return the stripped label (class name) column of a report row."""
        return line[:cls_field_width].strip()

    def parse_scores(line):
        """Parse a row's numeric columns into (precision, recall, fscore, support)."""
        fields = line[cls_field_width:].split()
        if len(fields) == 4:
            precision, recall, fscore, support = fields
            return (float(precision), float(recall), float(fscore), int(support))
        if len(fields) == 2:
            # 'accuracy' row: the precision and recall columns are blank.
            fscore, support = fields
            return (None, None, float(fscore), int(support))
        raise ValueError('cannot parse classification report row: %r' % line)

    data = collections.OrderedDict()
    if rows and rows[-1].split()[0] == 'avg':
        # Old layout: class rows followed by a single 'avg / total' row.
        for line in rows[:-1]:
            data[label_of(line)] = parse_scores(line)
        data['avg'] = parse_scores(rows[-1])
        return data

    # Newer layout: class rows followed by one or more labelled summary rows.
    summary_labels = ('accuracy', 'micro avg', 'macro avg', 'weighted avg',
                      'samples avg')
    if not rows or label_of(rows[-1]) not in summary_labels:
        raise ValueError('classification report has no average row')
    for line in rows:
        data[label_of(line)] = parse_scores(line)
    return data
#parse_classification_report(clfreport) | |
## | |
# Characters with a special meaning in LaTeX text mode, mapped to their
# escaped form.
_LATEX_SPECIAL_CHARS = {
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
}


def _latex_escape(text):
    """Escape LaTeX special characters in *text* so it typesets literally."""
    # Per-character mapping, so an escaped backslash is never re-escaped.
    return ''.join(_LATEX_SPECIAL_CHARS.get(char, char) for char in text)


def report_to_latex_table(data):
    """
    Render parsed classification-report data as a LaTeX tabular environment.

    data maps a row label (string) to a sequence of cell values, such as the
    OrderedDict returned by parse_classification_report. Rows are emitted in
    the mapping's iteration order. Labels are LaTeX-escaped, and None cells
    (e.g. precision/recall of an 'accuracy' row) are left empty.

    Returns the table source as a string without a trailing newline.
    """
    parts = [
        "\\begin{tabular}{c | c c c c}\n",
        "Class & Precision & Recall & F-score & Support\\\\\n",
        "\\hline\n",
        "\\hline\\\\\n",
    ]
    # items() works on both Python 2 and 3 (iteritems is Python 2-only).
    for cls, scores in data.items():
        cells = ['' if s is None else str(s) for s in scores]
        parts.append(_latex_escape(cls) + " & " + " & ".join(cells) + "\\\\\n")
    parts.append("\\end{tabular}")
    return "".join(parts)
#print report_to_latex_table(data) | |
## | |
if __name__ == '__main__':
    # Usage: python <this script> [REPORT_FILE]
    # Reads the classification_report text from REPORT_FILE, or from stdin
    # when no file is given, and prints the LaTeX table to stdout.
    if len(sys.argv) > 1:
        with open(sys.argv[1]) as f:
            report_text = f.read()
    else:
        report_text = sys.stdin.read()
    data = parse_classification_report(report_text)
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(report_to_latex_table(data))
## |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment