Skip to content

Instantly share code, notes, and snippets.

@bamos
Created April 3, 2017 11:13
Show Gist options
  • Save bamos/8765f5ee89c3fc261122f5e9a9244da7 to your computer and use it in GitHub Desktop.
Save bamos/8765f5ee89c3fc261122f5e9a9244da7 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import argparse
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')
import numpy as np
import pandas as pd
import math
from matplotlib import rcParams
rcParams.update({'figure.autolayout': True})
import os
import sys
import json
import glob
def main():
    """Render a bar chart comparing Gurobi and qpth QP-solver runtimes.

    Plots hard-coded mean runtimes (with standard-deviation error bars) for
    three batch sizes on a log-scaled y axis, then writes ``<out>.pdf`` and
    ``<out>.png``. The PDF is cropped in place with the external ``pdfcrop``
    tool when that tool is installed; otherwise it is left uncropped.

    Command-line arguments:
        --out: Output path prefix without extension
               (default: ``images/qpth-timing``). Missing parent
               directories are created.
    """
    import subprocess

    parser = argparse.ArgumentParser()
    parser.add_argument('--out', type=str, default='images/qpth-timing')
    args = parser.parse_args()

    batchSzs = [1, 64, 128]
    # One entry per bar series: (legend label, mean runtimes in seconds,
    # standard deviations), each list ordered like batchSzs.
    series = [
        ('Gurobi',
         [3.60602e-02, 2.33186e+00, 4.66950e+00],
         [3.86304e-03, 2.29599e-02, 2.93522e-02]),
        ('qpth (Single)',
         [7.75610e-02, 3.30584e+00, 6.64419e+00],
         [2.73885e-02, 7.14157e-02, 6.35851e-02]),
        ('qpth (Batched)',
         [6.98230e-02, 1.45570e-01, 1.83456e-01],
         [6.51156e-03, 2.80604e-03, 2.77888e-03]),
    ]

    indices = np.arange(len(batchSzs))
    barWidth = 0.2
    colors = plt.get_cmap("Set1")(np.linspace(0, 1, 9))
    alpha = 0.7

    fig, ax = plt.subplots(1, 1, figsize=(5, 3))
    for i, (label, means, stds) in enumerate(series):
        # Explicit align='center' so bar placement does not depend on the
        # matplotlib version's default alignment.
        ax.bar(indices + i * barWidth, means, barWidth,
               yerr=stds, label=label, align='center',
               color=colors[i], ecolor='0.3', alpha=alpha)

    # With centred bars, the middle of each group of len(series) bars is
    # offset by half the span between the first and last bar centres.
    ax.set_xticks(indices + barWidth * (len(series) - 1) / 2)
    ax.set_xticklabels(batchSzs)
    ax.set_xlabel("Batch Size")
    ax.set_ylabel("Runtime (s)")
    ax.set_yscale('log')
    # Each series carries a label; without a legend the colours are
    # indistinguishable to the reader.
    ax.legend(loc='upper left')

    outDir = os.path.dirname(args.out)
    if outDir:
        os.makedirs(outDir, exist_ok=True)

    for ext in ['pdf', 'png']:
        f = args.out + '.' + ext
        fig.savefig(f)
        if ext == 'pdf':
            # Pass argv as a list (no shell) so paths containing quotes or
            # shell metacharacters are handled safely. Cropping is
            # best-effort: a missing or failing pdfcrop does not abort.
            try:
                subprocess.run(['pdfcrop', f, f], check=False)
            except FileNotFoundError:
                print("pdfcrop not found; leaving {} uncropped".format(f),
                      file=sys.stderr)
        print("Created {}".format(f))

    plt.close(fig)
if __name__ == '__main__':
    # Build the plot only when run as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment