Created
November 5, 2019 10:21
-
-
Save andrzejnovak/940f5975ae5541c7bcce2f102e0d835e to your computer and use it in GitHub Desktop.
ROC Plots
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Remove jets with unphysical DeepCSV outputs.
def cleandf(tdf, max_rows=2000000):
    """Keep jets whose DeepCSV b and c outputs lie strictly inside (0, 1).

    Args:
        tdf: DataFrame with ``Jet_btagDeepB`` and ``Jet_btagDeepC`` columns.
        max_rows: keep at most this many of the surviving rows (positional,
            from the top). ``None`` disables the cap. Defaults to 2,000,000,
            the previously hard-coded limit.

    Returns:
        The filtered (and possibly truncated) DataFrame.
    """
    deep_b = tdf['Jet_btagDeepB']
    deep_c = tdf['Jet_btagDeepC']
    valid = (deep_b > 0) & (deep_b < 1) & (deep_c > 0) & (deep_c < 1)
    cdf = tdf[valid]
    if max_rows is not None:
        # .iloc makes the cut explicitly positional, independent of the index type.
        cdf = cdf.iloc[:max_rows]
    return cdf
def _add_deepcsv_columns(frame):
    """Add flavour-truth columns and DeepCSV/CSVv2 discriminators to ``frame`` in place.

    The ``predict*`` columns are filled with the same values as the matching
    ``truth*`` columns. Presumably they are placeholders, so that
    ``roc_input`` finds them while the ROC itself is built from tagger
    columns — TODO confirm.
    """
    flavour = frame['Jet_hadronFlavour']
    prob_b = frame['Jet_btagDeepB']
    prob_c = frame['Jet_btagDeepC']
    frame['fj_pt'] = frame['Jet_pt']
    # hadronFlavour 5 -> b, 4 -> c, anything below 4 -> light (udsg)
    for suffix, selector in (('b', flavour == 5), ('c', flavour == 4), ('udsg', flavour < 4)):
        frame['truth' + suffix] = selector.astype(int)
        frame['predict' + suffix] = selector.astype(int)
    # Two-class discriminators derived from the DeepCSV probabilities
    frame['dCvL'] = prob_c / (1 - prob_b)
    frame['dCvB'] = prob_c / (prob_c + prob_b)
    frame['dBvL'] = prob_b / (1 - prob_c)
    frame['dBvC'] = prob_b / (prob_c + prob_b)
    frame['CSVv2BvL'] = frame['Jet_btagCSVV2']


for df in [df16, df17, df18]:
    _add_deepcsv_columns(df)
def _add_deepjet_columns(frame):
    """Add DeepJet (DeepFlavour) CvL/CvB/BvL/BvC discriminator columns to ``frame`` in place."""
    prob_b = frame['Jet_btagDeepFlavB']
    prob_c = frame['Jet_btagDeepFlavC']
    frame['dfCvL'] = prob_c / (1 - prob_b)
    frame['dfCvB'] = prob_c / (prob_c + prob_b)
    frame['dfBvL'] = prob_b / (1 - prob_c)
    frame['dfBvC'] = prob_b / (prob_c + prob_b)


for df in [df16, df17, df18]:
    _add_deepjet_columns(df)
def roc_input(frame, signal=["HCC"], include=["HCC", "Light", "gBB", "gCC", "HBB"], norm=False, tagger='fj_doubleb'):
    """Select jets of the ``include`` classes and return per-jet ROC inputs.

    Args:
        frame: DataFrame with ``truth<label>`` and ``predict<label>`` columns
            for every label in ``include``, plus the ``tagger`` column.
        signal: labels treated as signal; their truth and predict columns are
            summed per row.
        include: labels kept (signal plus background). Rows whose summed
            ``truth<label>`` over these labels is zero are dropped.
        norm: if truthy, divide the summed signal prediction by the summed
            prediction over all ``include`` labels.
        tagger: column whose values are returned as the third output.

    Returns:
        ``(truth, predict, tag_vals)`` numpy arrays over the selected rows.
    """
    # Vectorized row mask: keep rows that belong to any included class.
    in_scope = np.zeros(frame.shape[0])
    for label in include:
        in_scope = in_scope + frame['truth' + label].values
    tdf = frame[in_scope.astype(bool)]  # tdf for temporary df

    truth = np.zeros(tdf.shape[0])
    predict = np.zeros(tdf.shape[0])
    for label in signal:
        truth += tdf['truth' + label].values
        predict += tdf['predict' + label].values
    tag_vals = tdf[tagger].values
    if not norm:
        return truth, predict, tag_vals

    # Normalise the signal score by the total score over the included classes.
    prednorm = np.zeros(tdf.shape[0])
    for label in include:
        prednorm += tdf['predict' + label].values
    return truth, np.divide(predict, prednorm), tag_vals
# Measured-WP table used when ``measure_wps`` is True. Each row maps label
# substrings to target mistag rates and the WP names drawn next to them.
_MEASURED_WP_TABLE = (
    (("DeepDoubleB", "ZHbb"), [0.003, 0.005, 0.01, 0.02, 0.05], ['T2', 'T1', "M2", 'M1', "L"]),
    (("DeepDoubleCvL", "ZHcc"), [0.01, 0.02, 0.05, 0.1], ['T', "M2", 'M1', "L"]),
    # Five names but four rates: zip() in the caller only labels T, M, L, UL2.
    (("DeepDoubleCvB",), [0.012, 0.02, 0.05, 0.1], ['T', "M", 'L', "UL2", "UL1"]),
)


def _measured_wps(lab):
    """Return ``(wp_names, mistag_rates)`` for a curve label in measured-WP mode.

    Labels matching no row fall back to the first row, which is what every
    label received while the label test was an always-true ``or`` expression.
    """
    for keys, rates, wp_names in _MEASURED_WP_TABLE:
        if any(key in lab for key in keys):
            return list(wp_names), list(rates)
    _, rates, wp_names = _MEASURED_WP_TABLE[0]
    return list(wp_names), list(rates)


def _fixed_wps(lab):
    """Return ``(wp_names, cut_values)`` for a curve label in fixed-cut mode.

    Unknown labels yield two empty lists, so no WP is drawn.
    """
    wp_ns, cut_vals = [], []
    if "CvL" in lab and "CSV" in lab:
        wp_ns, cut_vals = ['CvL = 0.4'], [0.4]
    elif "CvB" in lab and "CSV" in lab:
        wp_ns, cut_vals = ['CvB = 0.2'], [0.2]
    elif "DeepDoubleB" in lab:
        wp_ns, cut_vals = ['L', 'M1', "M2", 'T1', "T2"], [0.7, 0.86, 0.89, 0.91, 0.92]
    elif "DeepDoubleCvL" in lab:
        wp_ns, cut_vals = ['L', 'M1', "M2", 'T'], [0.59, 0.7, 0.79, 0.83]
    elif "ZHbb" in lab:
        wp_ns, cut_vals = ['L', 'M1', "M2", 'T1', "T2"], [0.67, 0.90, 0.95, 0.97, 0.98]
    # Checked last, so it overrides any match above.
    if "double-b" in lab:
        wp_ns, cut_vals = ['L', "M1", 'M2', "T"], [0.3, 0.6, 0.8, 0.9]
    return wp_ns, cut_vals


def _wp_label_placement(show_wps):
    """Return ``(xytext offset in points, ha, va)`` for WP labels given a ``wps`` entry."""
    placements = {
        'left': ((-7, 0), 'right', 'bottom'),
        'right': ((10, 0), 'left', 'bottom'),
        'top': ((0, 7), 'center', 'bottom'),
        'bottom': ((0, -10), 'center', 'top'),
        'bot': ((0, -10), 'center', 'top'),
        'bottom-right': ((10, -10), 'left', 'top'),
        'bottom-left': ((30, -50), 'right', 'top'),
    }
    if isinstance(show_wps, str) and show_wps in placements:
        return placements[show_wps]
    return (-7, 7), 'right', 'bottom'


def _plot_sf_points(ax, lab, effs, mts, color, year):
    """Overlay WP markers whose efficiencies are scaled by data/MC scale factors.

    The SF table is chosen from the curve label. Labels without a known
    table are skipped instead of raising or reusing another curve's table.
    """
    if "double-b" in lab:
        sf_tag, reverse = 'DoubleB', False
    elif "BvL" in lab:
        sf_tag, reverse = 'DDBvL', True
    elif "ZHbb" in lab:
        sf_tag, reverse = 'DeepAK8ZHbb', True
    else:
        return
    sf_path = '/home/anovak/Work/PyCFIT/SF_DoubleB/DF_resRun{}_{}.csv'.format(year, sf_tag)
    sf_df = pd.read_csv(sf_path, index_col=0).filter(like='pt350to2000', axis=0)
    if reverse:
        # Presumably these SF tables list WPs in the opposite order — TODO confirm.
        effs = effs[::-1]
        mts = mts[::-1]
    effs = np.asarray(effs)
    v_sf = sf_df['SF'].to_numpy()
    v_sf_up = v_sf + sf_df['Combined up'].to_numpy()
    v_sf_down = v_sf - sf_df['Combined down'].to_numpy()
    corr = v_sf * effs
    corr_up = v_sf_up * effs - corr
    corr_down = corr - v_sf_down * effs
    # errorbar's (2, N) xerr is [minus, plus]: down-variation to the left, up to the right.
    ax.errorbar(corr, mts, yerr=0, xerr=[corr_down, corr_up], color=color, fmt='o',
                markerfacecolor='none', markersize=10, linewidth=3)


def _annotate_wps(ax, fpr, tpr, threshold, lab, show_wps, color, measure_wps, plot_sf, year):
    """Mark, label and print the working points of one ROC curve.

    In measured mode, WPs are the curve points nearest to target mistag
    rates. Otherwise they are the points nearest to fixed discriminator cuts.
    """
    if measure_wps:
        wp_ns, mts = _measured_wps(lab)
        cut_vals = []
    else:
        wp_ns, cut_vals = _fixed_wps(lab)
        mts = []
    effs = []
    if len(mts) < 1:
        # Fixed cuts: look up mistag rate and efficiency at each cut.
        for cut_val in cut_vals:
            idx, _ = find_nearest(threshold, cut_val)
            mts.append(fpr[idx])
            effs.append(tpr[idx])
    else:
        # Target mistag rates: look up efficiency and cut at each rate.
        for wp in mts:
            idx, _ = find_nearest(fpr, wp)
            effs.append(tpr[idx])
            cut_vals.append(threshold[idx])
    print(lab, "WPs:")
    print(np.round(cut_vals, 3))
    print("MTs", np.round(mts, 3))
    print("Effs", np.round(effs, 3))

    annot_offset, ha, va = _wp_label_placement(show_wps)
    for wp_n, wp_x, wp_y in zip(wp_ns, effs, mts):
        ax.annotate(wp_n, xy=(wp_x, wp_y), xytext=annot_offset, color=color, fontweight='bold',
                    textcoords="offset points", ha=ha, va=va)
    ax.plot(effs, mts, color=color, marker='+', mew=5, ms=20, linewidth=0)
    if plot_sf:
        _plot_sf_points(ax, lab, effs, mts, color, year)


def _flavour_axis_label(labels, sep):
    """Join class labels into a ``\\mathrm`` body; H/Z/g labels render as decays (e.g. Hcc -> H -> c cbar)."""
    def fmt(label):
        if label[0][0] in ["H", "Z", "g"]:
            return '{} \\rightarrow {}'.format(label[0], label[-2] + '\\bar{' + label[-1] + '}')
        return label
    return '{' + sep.join(fmt(label) for label in labels) + '}'


def _set_axes_size(w, h, ax=None):
    """Resize the figure so the axes area is ``w`` x ``h`` inches (defaults to the current axes)."""
    if not ax:
        ax = plt.gca()
    left = ax.figure.subplotpars.left
    right = ax.figure.subplotpars.right
    top = ax.figure.subplotpars.top
    bottom = ax.figure.subplotpars.bottom
    ax.figure.set_size_inches(float(w) / (right - left), float(h) / (top - bottom))


def compare_rocs(dfs=[], names=[], sigs=[["Hcc"]], bkgs=[["Hbb"]], norm=False, pt=[300, 2000],
                 tagger_names=None,
                 flip=None,
                 title=None,
                 ignore=None,
                 wps=None,
                 measure_wps=True,
                 plotSF=False,
                 supp=False,
                 paper=False,
                 flip_anot=None,
                 plotname="", colors=[0, 1], styles=['-', '-'], year='2016', use_tagger=[False], log=True,
                 mlow=40, mhigh=200):
    """Overlay ROC curves (efficiency vs. mistag rate), then save and show the figure.

    All per-curve arguments are parallel lists; ``zip`` truncates them to the
    shortest one.

    Args:
        dfs: one DataFrame per curve, passed through ``cut`` and ``roc_input``.
        names: legend names. A ``"raw:"`` prefix is stripped; otherwise the
            name is prefixed with ``"DeepDouble"``.
        sigs, bkgs: per-curve lists of signal/background class labels for
            ``roc_input``.
        norm: forwarded to ``roc_input``; also selects the output file prefix.
        pt: ``[low, high]`` pT window passed to ``cut``; ``pt[0]`` is shown in
            the legend title.
        tagger_names: per-curve tagger column (default ``'fj_doubleb'``).
        flip: per-curve flag; when set, the prediction is replaced by
            ``1 - prediction``.
        title: optional axes title.
        ignore: per-curve flag to skip a curve.
        wps: per-curve WP annotation. ``False`` disables it; a placement
            string (``'left'``, ``'right'``, ``'top'``, ``'bottom'``/``'bot'``,
            ``'bottom-right'``, ``'bottom-left'``) or any other non-False value
            enables it.
        measure_wps: derive WPs from target mistag rates rather than fixed cuts.
        plotSF: overlay SF-corrected WPs (needs the SF CSV files on disk).
        supp, paper: CMS label variants.
        flip_anot: unused; kept for backward compatibility.
        plotname: if non-empty, used in the output file name.
        colors: per-curve colour; an ``int`` indexes the built-in palette.
        styles: per-curve line styles.
        year: used in the CMS label, the SF file name and the output file name.
        use_tagger: per-curve flag, repeated ``len(colors)`` times. When set,
            the ROC uses the tagger column instead of the summed predictions.
        log: use a logarithmic y axis.
        mlow, mhigh: mass window passed to ``cut`` (previously fixed at 40/200).

    Relies on module-level ``plt``, ``np``, ``pd``, ``os``, ``cut``,
    ``roc_curve``, ``find_nearest`` and ``savedir``.
    """
    c_list = ['darkorange', 'steelblue', 'firebrick', 'purple', 'orangered']
    f, ax = plt.subplots(figsize=(11, 10))
    if tagger_names is None:
        tagger_names = ['fj_doubleb'] * len(dfs)
    if ignore is None:
        ignore = [False] * len(dfs)
    if flip is None:
        flip = [False] * len(dfs)
    if wps is None:
        wps = [False] * len(dfs)

    per_curve = zip(dfs, tagger_names, names, sigs, bkgs, colors, styles,
                    use_tagger * len(colors), wps, ignore, flip)
    for frame, tagger, name, sig, bkg, col, sty, db_id, show_wps, skip, flip_this in per_curve:
        if skip:
            continue
        frame = cut(frame, ptlow=pt[0], pthigh=pt[1], mlow=mlow, mhigh=mhigh)
        truth, predict, db = roc_input(frame, signal=sig, include=sig + bkg, norm=norm, tagger=tagger)
        if flip_this:
            predict = 1 - predict
        # db_id selects the raw tagger column instead of the summed predictions.
        fpr, tpr, threshold = roc_curve(truth, db if db_id else predict)
        color_toplot = c_list[col] if isinstance(col, int) else col
        lab = name[len('raw:'):] if name.startswith("raw:") else "DeepDouble{}".format(name)
        ax.plot(tpr, fpr, lw=4, label=lab, color=color_toplot, linestyle=sty)
        if show_wps != False:  # noqa: E712 - any non-False entry (e.g. 'left', True) enables WPs
            _annotate_wps(ax, fpr, tpr, threshold, lab, show_wps, color_toplot,
                          measure_wps, plotSF, year)

    if plotSF:
        # Legend proxy for the SF-corrected markers.
        ax.plot([], [], color='grey', marker='o', markersize=10, markerfacecolor='none', linewidth=3,
                label="Data/MC SF Adjusted")
    ax.set_xlim(0, 1)
    ax.set_ylim(0.001, 1)

    sig_names = sorted(set(item for group in sigs for item in group))
    bkg_names = sorted(set(item for group in bkgs for item in group))
    ax.set_xlabel(r'Tagging efficiency ($\mathrm{}$)'.format(_flavour_axis_label(sig_names, ", ")),
                  ha='right', x=1.0)
    ax.set_ylabel(r'Mistagging rate ($\mathrm{}$)'.format(_flavour_axis_label(bkg_names, " / ")),
                  ha='right', y=1.0)

    import matplotlib.ticker as plticker
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
    ax.yaxis.set_ticks_position('both')
    ax.grid(which='minor', alpha=0.5, axis='y', linestyle='dotted')
    ax.grid(which='major', alpha=0.9, linestyle='dotted')

    leg = ax.legend(borderpad=1, frameon=False, loc=2, fontsize=18)
    legtitle = (r"$\mathrm{t\bar{t}}$ events" + "\n" + "AK4Jets " + r"$\mathrm{p_T >}$ "
                + str(int(round(pt[0]))) + " GeV")
    leg.set_title(legtitle, prop={'size': 22})
    leg.get_title().set_linespacing(1.5)
    leg._legend_box.align = "left"

    import mplhep.cms as cms
    if paper:
        if supp:
            ax = cms.cmslabel(ax, year=year, paper=True, supplementary=True)
        else:
            ax = cms.cmslabel(ax, year=year, paper=True)
    else:
        ax = cms.cmslabel(ax, year=year)
    if title is not None:
        ax.set_title(title)
    if log:
        ax.semilogy()

    # Axes size in inches before and after resizing (diagnostic output).
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).width)
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).height)
    _set_axes_size(8, 8)
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).width)
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).height)

    # Any non-empty plotname is used (a one-character name used to be ignored).
    if plotname:
        stem = "ROCComparison_" + plotname + year
    else:
        prefix = "ROCNormComparison_" if norm else "ROCComparison_"
        stem = prefix + year + "_" + "+".join(names)
    stem += "pt{}-{}".format(pt[0], pt[1])
    for ext in (".pdf", ".png"):
        f.savefig(os.path.join(savedir, stem + ext), dpi=400, transparent=True)
    print(ax.get_window_extent())
    plt.show()
savedir = '/home/anovak/Desktop/ctag/'
plt.style.use(['ROOT'])
import warnings
warnings.filterwarnings("ignore")

# One row per ROC curve: (legend name, tagger column, colour, line style, signal, background).
# NOTE(review): the CSVv2 curve is compared against c only — confirm this is intended.
_b_tag_curves = [
    ("raw:DeepCSV - BvL", 'dBvL', 'darkblue', '-', ["b"], ["udsg"]),
    ("raw:DeepCSV - BvC", 'dBvC', 'darkblue', '--', ["b"], ["c"]),
    ("raw:DeepJet - BvL", 'dfBvL', 'red', '-', ["b"], ["udsg"]),
    ("raw:DeepJet - BvC", 'dfBvC', 'red', '--', ["b"], ["c"]),
    ("raw:CSVv2", 'Jet_btagCSVV2', 'green', '-', ["b"], ["c"]),
]
_curve_names, _curve_taggers, _curve_colors, _curve_styles, _curve_sigs, _curve_bkgs = (
    list(column) for column in zip(*_b_tag_curves)
)
compare_rocs(
    dfs=[df16] * len(_b_tag_curves),
    names=_curve_names,
    tagger_names=_curve_taggers,
    use_tagger=[True] * len(_b_tag_curves),
    ignore=[False] * len(_b_tag_curves),
    colors=_curve_colors,
    pt=[20, 2000],
    supp=True,
    styles=_curve_styles,
    sigs=_curve_sigs,
    bkgs=_curve_bkgs,
    plotname="test",
    year="2016",
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment