Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shashankg7/012cded1748d6774d29d9e41819bb70f to your computer and use it in GitHub Desktop.
Save shashankg7/012cded1748d6774d29d9e41819bb70f to your computer and use it in GitHub Desktop.
Parse a scikit-learn classification_report and convert it to a LaTeX table
"""
Parse a sklearn classification_report and render it as a LaTeX table.
"""
##
import sys
import collections
##
def parse_classification_report(clfreport):
    """
    Parse a sklearn classification report into an ordered dict.

    The report text must consist of a header line whose columns are exactly
    ``precision recall f1-score support``, one line per class, and a final
    summary line whose first token is ``avg`` (for example ``avg / total``).
    Blank lines are ignored.

    Parameters
    ----------
    clfreport : str
        Text of the classification report.

    Returns
    -------
    collections.OrderedDict
        Maps each class name, in report order, to a tuple
        ``(precision, recall, fscore, support)`` holding three floats and an
        int. The summary line is stored last under the key ``'avg'``.

    Raises
    ------
    IndexError
        If the report contains no non-empty line.
    AssertionError
        If the header or the summary line does not have the expected form.
    ValueError
        If a score line does not hold exactly four numeric fields.
    """
    # Build a real list rather than a lazy ``filter`` object: the lines are
    # indexed and sliced below, which an iterator does not support.
    lines = [line for line in clfreport.split('\n') if line.strip()]

    # Starts with a header, then scores for each class and finally an average
    header = lines[0]
    cls_lines = lines[1:-1]
    avg_line = lines[-1]
    assert header.split() == ['precision', 'recall', 'f1-score', 'support']
    assert avg_line.split()[0] == 'avg'

    # We cannot simply use split because class names can have spaces. So
    # instead figure the width of the class field by looking at the
    # indentation of the precision header
    cls_field_width = len(header) - len(header.lstrip())

    def parse_line(line):
        """Split one report line into (name, precision, recall, f1, support)."""
        cls_name = line[:cls_field_width].strip()
        precision, recall, fscore, support = line[cls_field_width:].split()
        return (cls_name, float(precision), float(recall), float(fscore),
                int(support))

    # Collect all the class names and scores, preserving report order
    data = collections.OrderedDict()
    for line in cls_lines:
        parsed = parse_line(line)
        data[parsed[0]] = parsed[1:]

    # The summary row is always stored under the fixed key 'avg', whatever
    # label text follows it in the report
    data['avg'] = parse_line(avg_line)[1:]
    return data
#parse_classification_report(clfreport)
##
def report_to_latex_table(data):
    """
    Render parsed classification scores as a LaTeX ``tabular`` environment.

    Parameters
    ----------
    data : mapping
        Maps a class name (str) to an iterable of scores. Rows are emitted
        in the mapping's iteration order, each score converted with ``str``.

    Returns
    -------
    str
        The LaTeX source of the table, without a trailing newline.
    """
    # NOTE: class names are inserted verbatim; characters that are special
    # in LaTeX (such as ``_`` or ``&``) are not escaped here.
    parts = [
        "\\begin{tabular}{c | c c c c}\n",
        "Class & Precision & Recall & F-score & Support\\\\\n",
        "\\hline\n",
        "\\hline\\\\\n",
    ]
    # ``items()`` is available on both Python 2 and Python 3 mappings,
    # unlike the Python-2-only ``iteritems()``.
    for cls, scores in data.items():
        parts.append(cls + " & " + " & ".join(str(s) for s in scores))
        parts.append("\\\\\n")
    parts.append("\\end{tabular}")
    return "".join(parts)
#print report_to_latex_table(data)
##
if __name__ == '__main__':
    # Command-line entry point: read a classification report from the file
    # named by the first argument and print the LaTeX table to stdout.
    if len(sys.argv) < 2:
        # Exit with a usage message instead of a bare IndexError traceback.
        sys.exit("usage: %s REPORT_FILE" % sys.argv[0])
    with open(sys.argv[1]) as f:
        data = parse_classification_report(f.read())
    # Parenthesised form works as a call on Python 3 and as a statement
    # with a single parenthesised expression on Python 2.
    print(report_to_latex_table(data))
##
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment