Last active
November 24, 2017 09:56
-
-
Save olbat/41bc27ad1da84742c812b42ebf0c5868 to your computer and use it in GitHub Desktop.
Python3 matplotlib script that plots histograms for features of an ML corpus
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""
usage: {} < corpus.json > plot.pdf
The corpus file must contain one JSON document per line,
features must be stored in a field names "{}",
classes in a field names "{}".
"""
import sys
import json
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import EngFormatter

# Name of the JSON field holding the list of feature values of a sample.
FEATURES_FIELD = "features"
# Name of the JSON field holding the class label of a sample.
CLASS_FIELD = "category"
# Figure size multipliers, paired index-for-index with the subplot grid shape
# (rows, columns): RATIO[0] is applied per grid row, RATIO[1] per grid column.
RATIO = (2, 3)


def load_corpus(stream):
    """Read a JSON-lines corpus and group its feature vectors.

    Args:
        stream: iterable of text lines, one JSON document per line. Each
            document must contain FEATURES_FIELD and CLASS_FIELD.

    Returns:
        A ``(features, class_features)`` pair. ``features`` is the transposed
        array of every sample's feature list, so iterating over it yields one
        sequence of values per feature. ``class_features`` maps each class
        label to the same kind of transposed array, restricted to the samples
        of that class (in order of first appearance of the label).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a document lacks one of the required fields.
    """
    samples = []
    by_class = defaultdict(list)
    for line in stream:
        # Blank lines (typically a trailing newline at the end of the file)
        # are not JSON documents; skip them instead of crashing.
        if not line.strip():
            continue
        doc = json.loads(line)
        samples.append(doc[FEATURES_FIELD])
        by_class[doc[CLASS_FIELD]].append(doc[FEATURES_FIELD])

    features = np.transpose(np.array(samples))
    class_features = {
        label: np.transpose(np.array(rows))
        for label, rows in by_class.items()
    }
    return features, class_features


def plot_histograms(features, class_features):
    """Build the grid of histograms and return the figure.

    The grid has one row per feature. The first column shows the histogram of
    the feature over the whole corpus; each following column shows it for one
    class, using the y-axis limits of the first column so rows are comparable.

    Args:
        features: non-empty sequence with one sequence of values per feature.
        class_features: mapping of class label to a sequence shaped like
            ``features`` but restricted to that class.

    Returns:
        The matplotlib figure holding all the subplots.
    """
    n_rows = len(features)
    n_cols = len(class_features) + 1
    grid_shape = (n_rows, n_cols)

    plt.rcParams["axes.grid"] = True
    plt.rcParams["grid.linestyle"] = "dotted"

    # figsize is (width, height) while the grid shape is (rows, columns):
    # the width must scale with the column count and the height with the row
    # count, otherwise tall grids come out as wide, flattened figures.
    per_row, per_col = RATIO
    fig = plt.figure(figsize=(per_col * n_cols, per_row * n_rows))

    # display information about the corpus as suptitle
    per_class = ", ".join(
        "{}:{}".format(label, len(values[0]))
        for label, values in class_features.items())
    fig.suptitle("samples:{} (".format(len(features[0])) + per_class + ")")

    # display features subplots (first column, whole corpus)
    ylims = []
    for row, values in enumerate(features):
        axis = plt.subplot2grid(grid_shape, (row, 0))
        axis.yaxis.set_major_formatter(EngFormatter())
        if row == 0:
            axis.set_title("all")
        axis.set_ylabel("feature #{}".format(row))
        axis.yaxis.set_label_coords(-0.15, 0.5)
        axis.hist(values)
        # remembered so the per-class plots of this row share the same scale
        ylims.append(axis.get_ylim())

    # display per-class features subplots (one column per class)
    for col, (label, per_feature) in enumerate(class_features.items(), 1):
        for row, values in enumerate(per_feature):
            axis = plt.subplot2grid(grid_shape, (row, col))
            axis.yaxis.set_major_formatter(EngFormatter())
            if row == 0:
                axis.set_title("{}".format(label))
            axis.hist(values)
            axis.set_ylim(ylims[row])

    # fix layout, then leave room at the top for the suptitle
    fig.tight_layout()
    fig.subplots_adjust(top=0.90)
    return fig


def main():
    """Read the corpus from stdin and write the histogram plot to stdout.

    Returns:
        The process exit status: 0 on success, 1 when the help text was shown
        or when the corpus contains nothing to plot.
    """
    # display help if no data on stdin
    if sys.stdin.isatty():
        print(__doc__.format(sys.argv[0], FEATURES_FIELD, CLASS_FIELD))
        return 1

    features, class_features = load_corpus(sys.stdin)
    if features.size == 0:
        # stdout is reserved for the PDF, so diagnostics go to stderr
        print("error: no feature values found on stdin", file=sys.stderr)
        return 1

    fig = plot_histograms(features, class_features)

    # output the resulting plot
    fig.savefig(sys.stdout.buffer, format="pdf")
    plt.close(fig)
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment