Created
January 28, 2020 20:49
-
-
Save jtrive84/919dde6b2536c35ecddf89005733cf2e to your computer and use it in GitHub Desktop.
Model frequency exhibit generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _rate_axis_bounds(max_rate):
    """
    Return ``(y_min, y_max, y_incr)`` for the rate axis given the largest
    plotted rate.
    """
    if max_rate <= .05:
        return 0, .06, .01
    if max_rate <= .10:
        return 0, .11, .01
    if max_rate <= .12:
        return 0, .13, .02
    if max_rate <= .15:
        return 0, .17, .02
    # Anything above .15 lands here. So does NaN (it fails every comparison
    # above), which previously left the axis bounds undefined.
    return 0, .22, .02


def _overlay_rate_table(dfvar, var, modelslist, x_vals):
    """
    Draw a table of per-level rates and actual/expected ratios on the
    current axes.
    """
    actual_rates = dfvar["actual_rate"].values
    # Use one predicate for both the A/E labels and the A/E data so that the
    # number of row labels always matches the number of table rows.
    modeled = [m for m in modelslist if not m["rate"].startswith("actual")]

    row_labels = [m["legend"] for m in modelslist]
    row_labels += [m["legend"] + "_a/e" for m in modeled]

    table_data = [
        ["{:.5f}".format(v) for v in dfvar[m["rate"]]] for m in modelslist
    ]
    table_data += [
        ["{:.5f}".format(v) for v in (actual_rates / dfvar[m["rate"]].values).tolist()]
        for m in modeled
    ]

    # Column width scales with the longest x label, with a floor of .05.
    longest_label = dfvar[var].map(lambda v: len(str(v))).max()
    col_width = max(longest_label / 215, .05)

    plt.table(
        cellText=table_data, rowLabels=row_labels, colLabels=x_vals,
        loc="upper right", colWidths=[col_width] * len(x_vals),
        rowLoc="left", colLoc="center", bbox=None,
    )


def modelcomp(df, outpath, titlestr=None):
    """
    Generate partial residual model comparison exhibits.

    One PDF page is written per grouping variable. Each page shows summed
    exposure as bars and, on a secondary axis, the rate (count / exposure)
    of every count column as a line with square markers. Rates are capped
    one tick below the top of the axis so that they stay within the frame.

    Parameters
    ----------
    df: pd.DataFrame
        Must contain an ``EXPOSURE`` column. ``INCIDENCE_MNTH``, if present,
        is renamed to ``actual_count``. Every column whose name ends with
        ``count`` is treated as a series to plot; every remaining column is
        treated as a grouping variable. ``ATTAINED_AGE`` is restricted to
        ages 65 through 105 and uses a fixed rate axis.

    outpath: str
        Destination of the multi-page PDF.

    titlestr: str, optional
        Title prefix. When None, ``"Modeled Rates vs. <variable>"`` is used.

    Returns
    -------
    None

    Notes
    -----
    When a variable has 10 or fewer distinct levels a table of rates and
    actual/expected ratios is overlaid; this requires an ``actual_rate``
    series (i.e. an ``actual_count`` column). Variables whose grouped
    summary is empty are skipped.
    """
    df = df.rename({"INCIDENCE_MNTH": "actual_count"}, axis=1)
    cfields = [i for i in df.columns if i.endswith("count")]
    keepfields = cfields + ["EXPOSURE"]
    varcols = [i for i in df.columns if i not in keepfields]
    rfields = [i.replace("count", "rate") for i in cfields]
    rcapped = [i + "_capped" for i in rfields]
    lfields = [i.replace("_count", "").strip() for i in cfields]
    linesty = ["--" if i.startswith("actual") else "-" for i in cfields]
    colorlist = [
        "#2c475c", "#E02C70", "#6EA1D5", "#8e3e95", "#ff9325", "#ff88b8",
        "#ffce83", "#33f008", "#623381", "#eee646", "#50758b", "#389365",
        "#6EA1D5",
    ]

    # Plot parameters for each model to include with exhibits. Colors cycle
    # so that more models than palette entries does not raise IndexError.
    modelslist = [
        {
            "count": cfields[indx],
            "rate": rfields[indx],
            "legend": mdldesc,
            "color": "#888888" if mdldesc == "actual" else colorlist[indx % len(colorlist)],
            "capped": rcapped[indx],
            "linewidth": 1,
            "linestyle": linesty[indx],
        }
        for indx, mdldesc in enumerate(lfields)
    ]

    # The context manager closes the PDF on success and on error, and does
    # not reference an unbound name if PdfPages itself fails to open.
    with PdfPages(outpath) as pdf:
        for var in varcols:
            print(" + Summarizing modeled output via `{}`.".format(var))
            dfvar = df[keepfields + [var]].groupby(var, as_index=False).sum()
            for cstr, rstr in zip(cfields, rfields):
                dfvar[rstr] = dfvar[cstr] / dfvar["EXPOSURE"]

            if var == "ATTAINED_AGE":
                la, ua = 65, 105
                in_range = (dfvar["ATTAINED_AGE"] >= la) & (dfvar["ATTAINED_AGE"] <= ua)
                dfvar = dfvar[in_range].reset_index(drop=True)

            if dfvar.empty:
                # Nothing to plot (and the max-rate lookup below would fail).
                print("   - No data for `{}`; skipping.".format(var))
                continue

            if var == "ATTAINED_AGE":
                rotation, legendloc, labelsize, font = 90, "upper left", 6, 6
                y_axis_min, y_axis_incr, y_axis_max = 0, .02, .32
                # `np.int` was removed from NumPy; the builtin is equivalent.
                dfvar["ATTAINED_AGE"] = dfvar["ATTAINED_AGE"].astype(int)
            else:
                legendloc, rotation, labelsize, font = "lower left", 0, 7, 7
                y_axis_min, y_axis_max, y_axis_incr = _rate_axis_bounds(
                    dfvar[rfields].values.max()
                )

            # Compute rate capped fields to keep plotted rates within frame.
            rate_cap = y_axis_max - y_axis_incr
            for rf, rfc in zip(rfields, rcapped):
                dfvar[rfc] = dfvar[rf].clip(upper=rate_cap)

            if titlestr is None:
                titlestr_ = "Modeled Rates vs. {}".format(var)
            else:
                titlestr_ = titlestr + " (vs. {})".format(var)

            y_expos = dfvar["EXPOSURE"].values.tolist()
            x_vals = dfvar[var].values.tolist()
            x_indx = list(range(len(x_vals)))

            # Initialize plot.
            fig, ax1 = plt.subplots(1, 1, tight_layout=True)
            try:
                ax2 = ax1.twinx()
                ax1.set_title(titlestr_, color="red", loc="left", fontsize=7)
                ax1.bar(
                    x_indx, y_expos, color="#FFFFFF", edgecolor="#000000",
                    linewidth=.70, alpha=.875,
                )
                # `add_value_labels` is defined elsewhere; presumably it
                # annotates the exposure bars -- confirm against its definition.
                add_value_labels(ax1, spacing=3, annotate_font=font, rotation=rotation)

                # Add scatter points to axis; handles are reused by the legend.
                splist = [
                    ax2.scatter(
                        x_indx, dfvar[dmdl["capped"]], marker="s", label=dmdl["legend"],
                        edgecolor=dmdl["color"], color=dmdl["color"], s=9,
                    )
                    for dmdl in modelslist
                ]

                # Connect the points of each model with a line.
                for dmdl in modelslist:
                    ax2.plot(
                        x_indx, dfvar[dmdl["capped"]], color=dmdl["color"],
                        linewidth=dmdl["linewidth"], linestyle=dmdl["linestyle"],
                    )

                yrange = np.arange(y_axis_min, y_axis_max, y_axis_incr)
                yticklabels = ["{:.2f}".format(i) for i in yrange]
                ax1.set_xticks(x_indx)
                ax1.set_xticklabels(x_vals, rotation=0)
                ax2.yaxis.set_ticks(yrange)
                ax2.yaxis.set_ticklabels(yticklabels)
                ax1.tick_params(
                    axis="x", which="both", labelsize=labelsize, top=False,
                    bottom=True, labeltop=False, labelbottom=True, direction="out",
                )
                ax1.tick_params(
                    axis="y", which="both", left=False, right=False, labelleft=False,
                    labelright=False, length=0,
                )
                ax2.tick_params(
                    axis="x", which="both", top=False, bottom=False, labeltop=False,
                    labelbottom=False, length=0,
                )
                ax2.tick_params(
                    axis="y", which="both", left=False, right=True, color="red",
                    labelleft=False, labelright=True, labelsize=labelsize,
                    length=0,
                )
                ax1.grid(False)
                ax2.grid(False)
                ax1.yaxis.set_major_formatter(plt.NullFormatter())
                ax2.set_ylim(bottom=0)

                # Overlay table if number of distinct x_vals is at most 10.
                if len(x_vals) <= 10:
                    _overlay_rate_table(dfvar, var, modelslist, x_vals)

                legend = ax2.legend(
                    handles=splist, loc=legendloc, frameon=True, fontsize="medium",
                    fancybox=True, framealpha=.60,
                )
                legend.get_frame().set_facecolor("#FFFFFF")
                fig.savefig(pdf, format="pdf")
            finally:
                # Release the figure even if plotting or saving failed.
                plt.close(fig=fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment