Skip to content

Instantly share code, notes, and snippets.

@andrzejnovak
Created November 5, 2019 10:21
Show Gist options
  • Save andrzejnovak/940f5975ae5541c7bcce2f102e0d835e to your computer and use it in GitHub Desktop.
Save andrzejnovak/940f5975ae5541c7bcce2f102e0d835e to your computer and use it in GitHub Desktop.
ROC Plots
# Remove weird deepcsv events
def cleandf(tdf, max_rows=2000000):
    """Drop jets whose DeepCSV outputs lie outside (0, 1) and cap the row count.

    Parameters
    ----------
    tdf : pandas.DataFrame
        Must contain ``Jet_btagDeepB`` and ``Jet_btagDeepC`` columns.
    max_rows : int or None, default 2000000
        Keep at most this many of the surviving rows (by position, in their
        existing order). ``None`` disables the cap.

    Returns
    -------
    pandas.DataFrame
        The filtered (and possibly truncated) frame.
    """
    # Both discriminator outputs must lie strictly inside (0, 1); values at
    # or beyond the edges are treated as malformed entries.
    in_range = (
        (tdf['Jet_btagDeepB'] > 0) & (tdf['Jet_btagDeepB'] < 1)
        & (tdf['Jet_btagDeepC'] > 0) & (tdf['Jet_btagDeepC'] < 1)
    )
    cdf = tdf[in_range]
    if max_rows is not None:
        # Positional truncation, independent of the index labels.
        cdf = cdf.iloc[:max_rows]
    return cdf
def _add_deepcsv_columns(frame):
    """Add flavour flags and DeepCSV discriminant ratios to *frame* in place.

    For each of the classes b (hadronFlavour == 5), c (== 4) and udsg (< 4)
    two integer 0/1 columns are written: ``truth<cls>`` and ``predict<cls>``.
    Both are computed from the same truth selection, i.e. the "predict"
    columns simply mirror the truth here. The DeepCSV ratios CvL, CvB, BvL
    and BvC are then derived from ``Jet_btagDeepB`` / ``Jet_btagDeepC``, and
    the CSVv2 score is copied to ``CSVv2BvL``.
    """
    frame['fj_pt'] = frame['Jet_pt']

    hadron_flavour = frame['Jet_hadronFlavour']
    selections = (
        ('b', hadron_flavour == 5),
        ('c', hadron_flavour == 4),
        ('udsg', hadron_flavour < 4),
    )
    for suffix, selection in selections:
        frame['truth' + suffix] = selection.astype(int)
        frame['predict' + suffix] = selection.astype(int)

    prob_b = frame['Jet_btagDeepB']
    prob_c = frame['Jet_btagDeepC']
    frame['dCvL'] = prob_c / (1 - prob_b)
    frame['dCvB'] = prob_c / (prob_c + prob_b)
    frame['dBvL'] = prob_b / (1 - prob_c)
    frame['dBvC'] = prob_b / (prob_c + prob_b)
    frame['CSVv2BvL'] = frame['Jet_btagCSVV2']


for df in [df16, df17, df18]:
    _add_deepcsv_columns(df)
def _add_deepjet_columns(frame):
    """Add DeepJet (DeepFlav) discriminant ratios to *frame* in place.

    Written columns: ``dfCvL = C / (1 - B)``, ``dfCvB = C / (C + B)``,
    ``dfBvL = B / (1 - C)`` and ``dfBvC = B / (C + B)``, where B and C are
    ``Jet_btagDeepFlavB`` and ``Jet_btagDeepFlavC``.
    """
    flav_b = frame['Jet_btagDeepFlavB']
    flav_c = frame['Jet_btagDeepFlavC']
    frame['dfCvL'] = flav_c / (1 - flav_b)
    frame['dfCvB'] = flav_c / (flav_c + flav_b)
    frame['dfBvL'] = flav_b / (1 - flav_c)
    frame['dfBvC'] = flav_b / (flav_c + flav_b)


for df in [df16, df17, df18]:
    _add_deepjet_columns(df)
def roc_input(frame, signal=("HCC",), include=("HCC", "Light", "gBB", "gCC", "HBB"), norm=False, tagger='fj_doubleb'):
    """Build the arrays needed for one ROC curve.

    Rows are kept only if at least one ``truth<label>`` column is non-zero
    for a label in *include*; all other rows are discarded.

    Parameters
    ----------
    frame : pandas.DataFrame
        Must provide ``truth<label>`` for every label in *include*,
        ``predict<label>`` for every label in *signal* (and in *include* when
        *norm* is truthy), and the *tagger* column.
    signal : iterable of str
        Class labels treated as signal.
    include : iterable of str
        Class labels (signal and background) that define which rows are used.
    norm : bool
        If truthy, divide the summed signal ``predict`` columns by the sum of
        ``predict`` columns over all *include* labels.
    tagger : str
        Column whose values are returned as the third output.

    Returns
    -------
    truth : numpy.ndarray of float
        Per row, the sum of the signal ``truth`` columns.
    predict : numpy.ndarray of float
        Per row, the sum of the signal ``predict`` columns, normalized as
        described above when *norm* is truthy.
    tag_vals : numpy.ndarray
        Values of the *tagger* column for the selected rows.
    """
    # Keep only rows flagged as one of the included classes.
    class_hits = np.zeros(frame.shape[0])
    for label in include:
        class_hits += frame['truth' + label].values
    tdf = frame[class_hits != 0]  # tdf for temporary df

    n_rows = tdf.shape[0]
    truth = np.zeros(n_rows)
    predict = np.zeros(n_rows)
    for label in signal:
        truth += tdf['truth' + label].values
        predict += tdf['predict' + label].values
    tag_vals = tdf[tagger].values

    if not norm:
        return truth, predict, tag_vals

    prednorm = np.zeros(n_rows)
    for label in include:
        prednorm += tdf['predict' + label].values
    return truth, np.divide(predict, prednorm), tag_vals
def _working_points(lab, measure_wps):
    """Return ``(mistag_targets, wp_names, cut_values)`` for the curve labelled *lab*.

    With *measure_wps* the working points are defined by target mistag rates
    (``cut_values`` comes back empty for the caller to fill). Otherwise they
    are defined by fixed discriminator cuts (``mistag_targets`` comes back
    empty for the caller to fill). Fresh lists are returned on every call
    because the caller appends to them.

    In the measured mode, labels matching no known tagger fall back to the
    double-b mistag targets, which is the set the old always-true
    ``"DeepDoubleB" or ...`` test applied to every label. In the fixed-cut
    mode they get empty lists, so nothing is annotated.
    """
    if measure_wps:
        if "DeepDoubleB" in lab or "ZHbb" in lab:
            return [0.003, 0.005, 0.01, 0.02, 0.05], ['T2', 'T1', "M2", 'M1', "L"], []
        if "DeepDoubleCvL" in lab or "ZHcc" in lab:
            return [0.01, 0.02, 0.05, 0.1], ['T', "M2", 'M1', "L"], []
        if "DeepDoubleCvB" in lab:
            # Five names for four targets: the caller's zip() never uses "UL1".
            return [0.012, 0.02, 0.05, 0.1], ['T', "M", 'L', "UL2", "UL1"], []
        return [0.003, 0.005, 0.01, 0.02, 0.05], ['T2', 'T1', "M2", 'M1', "L"], []

    wp_ns, cut_vals = [], []
    if "CvL" in lab and "CSV" in lab:
        wp_ns, cut_vals = [r'CvL = 0.4'], [0.4]
    elif "CvB" in lab and "CSV" in lab:
        wp_ns, cut_vals = [r'CvB = 0.2'], [0.2]
    elif "DeepDoubleB" in lab:
        wp_ns, cut_vals = ['L', 'M1', "M2", 'T1', "T2"], [0.7, 0.86, 0.89, 0.91, 0.92]
    elif "DeepDoubleCvL" in lab:
        wp_ns, cut_vals = ['L', 'M1', "M2", 'T'], [0.59, 0.7, 0.79, 0.83]
    elif "ZHbb" in lab:
        wp_ns, cut_vals = ['L', 'M1', "M2", 'T1', "T2"], [0.67, 0.90, 0.95, 0.97, 0.98]
    # A separate `if` (as in the original): a "double-b" label overrides any match above.
    if "double-b" in lab:
        wp_ns, cut_vals = ['L', "M1", 'M2', "T"], [0.3, 0.6, 0.8, 0.9]
    return [], wp_ns, cut_vals


def _wp_annotation_style(placement):
    """Map a working-point label placement to ``(xytext offset in points, ha, va)``."""
    if placement == 'left':
        return (-7, 0), 'right', 'bottom'
    if placement == 'right':
        return (10, 0), 'left', 'bottom'
    if placement == 'top':
        return (0, 7), 'center', 'bottom'
    if placement == 'bottom' or placement == 'bot':
        return (0, -10), 'center', 'top'
    if placement == 'bottom-right':
        return (10, -10), 'left', 'top'
    if placement == 'bottom-left':
        return (30, -50), 'right', 'top'
    return (-7, 7), 'right', 'bottom'


def _load_sf_table(lab, year):
    """Return ``(sf_frame, reverse)`` with the data/MC SF table for *lab*, or ``(None, False)``.

    ``reverse`` tells the caller to flip its efficiency / mistag lists so they
    line up with the row order of the table. NOTE(review): that row-order
    assumption is inherited from the original code and is not checked here.
    """
    if "double-b" in lab:
        path, reverse = '/home/anovak/Work/PyCFIT/SF_DoubleB/DF_resRun{}_DoubleB.csv', False
    elif "BvL" in lab:
        path, reverse = '/home/anovak/Work/PyCFIT/SF_DoubleB/DF_resRun{}_DDBvL.csv', True
    elif "ZHbb" in lab:
        path, reverse = '/home/anovak/Work/PyCFIT/SF_DoubleB/DF_resRun{}_DeepAK8ZHbb.csv', True
    else:
        return None, False
    sf_frame = pd.read_csv(path.format(year), index_col=0).filter(like='pt350to2000', axis=0)
    return sf_frame, reverse


def _process_axis_label(processes, sep):
    """Format process names as a brace-wrapped mathtext fragment for an axis label.

    A name whose first character is H, Z or g is read as ``<parent><q><q>``
    (e.g. ``"Hcc"``) and rendered as ``H \\rightarrow c\\bar{c}``. Any other
    name is used verbatim. The parts are joined with *sep*.
    """
    parts = [
        '{} \\rightarrow {}'.format(proc[0], proc[-2] + '\\bar{' + proc[-1] + '}')
        if proc[0] in ["H", "Z", "g"] else proc
        for proc in processes
    ]
    return "{" + sep.join(parts) + "}"


def _set_axes_size(w, h, ax=None):
    """Resize the figure so the axes area of *ax* (default: ``plt.gca()``) is w x h inches."""
    if ax is None:
        ax = plt.gca()
    left = ax.figure.subplotpars.left
    right = ax.figure.subplotpars.right
    top = ax.figure.subplotpars.top
    bottom = ax.figure.subplotpars.bottom
    ax.figure.set_size_inches(float(w) / (right - left), float(h) / (top - bottom))


def compare_rocs(dfs=[], names=[], sigs=[["Hcc"]], bkgs=[["Hbb"]], norm=False, pt=[300,2000],
                 tagger_names=None,
                 flip=None,
                 title=None,
                 ignore=None,
                 wps=None,
                 measure_wps=True,
                 plotSF=False,
                 supp=False,
                 paper=False,
                 flip_anot=None,
                 plotname="", colors=[0,1], styles=['-','-'], year='2016', use_tagger=[False], log=True):
    """Draw several ROC curves (mistag rate vs. efficiency) on one axes, then save and show the figure.

    The list defaults are never mutated here, so sharing them across calls
    is harmless.

    Parameters (per-curve lists are consumed in lockstep by ``zip``, so the
    shortest one bounds the number of curves):

    dfs : list of pandas.DataFrame
        One frame per curve. Each is passed through ``cut()`` and ``roc_input()``.
    names : list of str
        Legend names. A ``"raw:"`` prefix is stripped and the rest used
        verbatim; otherwise the label is ``"DeepDouble" + name``.
    sigs, bkgs : list of list of str
        Per-curve signal / background class labels (``truth<label>`` /
        ``predict<label>`` column suffixes).
    norm : bool
        Forwarded to ``roc_input``. It also selects the ``ROCNorm`` file prefix.
    pt : [low, high]
        Forwarded to ``cut()``. Also used in the legend title and file name.
    tagger_names : list of str, optional
        Per-curve tagger column. Defaults to ``'fj_doubleb'`` for every curve.
    flip : list of bool, optional
        Per-curve: score with ``1 - predict``. NOTE(review): this has no
        effect on curves scored with the tagger column (``use_tagger``).
    title : str, optional
        Axes title.
    ignore : list of bool, optional
        Per-curve: skip the curve entirely.
    wps : list, optional
        Per-curve working-point annotation. ``False`` means none; otherwise a
        placement string ('left', 'right', 'top', 'bottom'/'bot',
        'bottom-right', 'bottom-left') or any other value for the default
        placement.
    measure_wps : bool
        Define working points by target mistag rate (True) or by fixed
        discriminator cuts (False).
    plotSF : bool
        Overlay data/MC scale-factor-adjusted working points read from CSV files.
    supp, paper : bool
        Forwarded to ``mplhep.cms.cmslabel``.
    flip_anot : unused
        Kept for call compatibility.
    plotname : str
        If longer than one character, used in the output file name instead
        of the joined *names*.
    colors : list of int or colour
        An int indexes a built-in palette; anything else is passed to matplotlib.
    styles : list of str
        Matplotlib line styles.
    year : str
        Used in labels, SF file paths and output file names.
    use_tagger : list of bool
        Repeated ``len(colors)`` times, so a one-element list applies to every
        curve. True scores the curve with the tagger column instead of the
        ``predict`` columns.
    log : bool
        Use a logarithmic y axis.

    Figures are written as PDF and PNG into the module-level ``savedir``.
    """
    c_list = ['darkorange', 'steelblue', 'firebrick', 'purple', 'orangered']
    f, ax = plt.subplots(figsize=(11, 10))
    if tagger_names is None: tagger_names = ['fj_doubleb'] * len(dfs)
    if ignore is None: ignore = [False] * len(dfs)
    if flip is None: flip = [False] * len(dfs)
    if wps is None: wps = [False] * len(dfs)

    mlow, mhigh = 40, 200  # mass window forwarded to cut()
    curves = zip(dfs, tagger_names, names, sigs, bkgs, colors, styles,
                 use_tagger * len(colors), wps, ignore, flip)
    for frame, tagger, name, sig, bkg, col, sty, db_id, show_wps, skip, flip_this in curves:
        if skip:
            continue
        frame = cut(frame, ptlow=pt[0], pthigh=pt[1], mlow=mlow, mhigh=mhigh)
        truth, predict, db = roc_input(frame, signal=sig, include=sig + bkg, norm=norm, tagger=tagger)
        if flip_this:
            predict = 1 - predict
        fpr, tpr, threshold = roc_curve(truth, db if db_id else predict)

        color_toplot = c_list[col] if isinstance(col, int) else col
        if name.startswith("raw:"):
            lab = name[len('raw:'):]
        else:
            lab = "DeepDouble{}".format(name)
        ax.plot(tpr, fpr, lw=4, label=lab, color=color_toplot, linestyle=sty)

        # Annotate working points
        if show_wps != False:
            mts, wp_ns, cut_vals = _working_points(lab, measure_wps)
            effs = []
            if len(mts) < 1:
                # Fixed cuts: read off mistag rate and efficiency at each cut.
                for cut_val in cut_vals:
                    idx, val = find_nearest(threshold, cut_val)
                    mts.append(fpr[idx])
                    effs.append(tpr[idx])
            else:
                # Target mistag rates: read off efficiency and cut at each rate.
                for wp in mts:
                    idx, val = find_nearest(fpr, wp)
                    effs.append(tpr[idx])
                    cut_vals.append(threshold[idx])
            print(lab, "WPs:")
            print(np.round(cut_vals, 3))
            print("MTs", np.round(mts, 3))
            print("Effs", np.round(effs, 3))

            annot_offset, ha, va = _wp_annotation_style(show_wps)
            for wp_n, wp_x, wp_y in zip(wp_ns, effs, mts):
                ax.annotate(wp_n, xy=(wp_x, wp_y), xytext=annot_offset, color=color_toplot, fontweight='bold',
                            textcoords="offset points", ha=ha, va=va)
            ax.plot(effs, mts, color=color_toplot, marker='+', mew=5, ms=20, linewidth=0)

            if plotSF:
                # A label with no SF table is skipped rather than reusing the
                # previous curve's table.
                sf_frame, reverse = _load_sf_table(lab, year)
                if sf_frame is not None:
                    if reverse:
                        effs = effs[::-1]
                        mts = mts[::-1]
                    sf = sf_frame['SF'].to_numpy()
                    sf_up = sf + sf_frame['Combined up'].to_numpy()
                    sf_down = sf - sf_frame['Combined down'].to_numpy()
                    corr = sf * effs
                    corrup = sf_up * effs - corr
                    corrdown = corr - sf_down * effs
                    # errorbar expects xerr=[lower, upper]: the down variation is the left bar.
                    ax.errorbar(corr, mts, yerr=0, xerr=[corrdown, corrup], color=color_toplot, fmt='o',
                                markerfacecolor='none', markersize=10, linewidth=3)

    if plotSF:
        ax.plot([], [], color='grey', marker='o', markersize=10, markerfacecolor='none', linewidth=3,
                label="Data/MC SF Adjusted")
    ax.set_xlim(0, 1)
    ax.set_ylim(0.001, 1)

    sig_procs = sorted(set(item for sublist in sigs for item in sublist))
    bkg_procs = sorted(set(item for sublist in bkgs for item in sublist))
    ax.set_xlabel(r'Tagging efficiency ($\mathrm{}$)'.format(_process_axis_label(sig_procs, ", ")),
                  ha='right', x=1.0)
    ax.set_ylabel(r'Mistagging rate ($\mathrm{}$)'.format(_process_axis_label(bkg_procs, " / ")),
                  ha='right', y=1.0)

    import matplotlib.ticker as plticker
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=0.1))
    ax.yaxis.set_ticks_position('both')
    ax.grid(which='minor', alpha=0.5, axis='y', linestyle='dotted')
    ax.grid(which='major', alpha=0.9, linestyle='dotted')

    leg = ax.legend(borderpad=1, frameon=False, loc=2, fontsize=18)
    legtitle = r"$\mathrm{t\bar{t}}$ events" + "\n" + "AK4Jets " + r"$\mathrm{p_T >}$ " + str(int(round((pt[0])))) + " GeV"
    leg.set_title(legtitle, prop={'size': 22})
    leg.get_title().set_linespacing(1.5)
    leg._legend_box.align = "left"  # private matplotlib attribute: left-align the legend title

    import mplhep.cms as cms
    if paper:
        if supp:
            ax = cms.cmslabel(ax, year=year, paper=True, supplementary=True)
        else:
            ax = cms.cmslabel(ax, year=year, paper=True)
    else:
        ax = cms.cmslabel(ax, year=year)
    if title is not None:
        ax.set_title(title)
    if log:
        ax.semilogy()

    # Debug output: axes size in inches before and after resizing.
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).width)
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).height)
    _set_axes_size(8, 8)
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).width)
    print(ax.get_window_extent().transformed(f.dpi_scale_trans.inverted()).height)

    pt_tag = "pt{}-{}".format(pt[0], pt[1])
    if len(plotname) > 1:
        stem = "ROCComparison_" + plotname + year + pt_tag
    else:
        prefix = "ROCNormComparison_" if norm else "ROCComparison_"
        stem = prefix + year + "_" + "+".join(names) + pt_tag
    for ext in (".pdf", ".png"):
        f.savefig(os.path.join(savedir, stem + ext), dpi=400, transparent=True)

    print(ax.get_window_extent())
    plt.show()
# Output directory; compare_rocs() reads this module-level name when saving.
savedir = '/home/anovak/Desktop/ctag/'
plt.style.use(['ROOT'])

import warnings
# Silences every Python warning from here on (applied after the style is set).
warnings.filterwarnings("ignore")

# 2016 sample, jets with pT in [20, 2000]: b-jet efficiency vs. udsg or c
# mistag rate for DeepCSV, DeepJet and CSVv2. Every curve is scored directly
# with its tagger column (use_tagger=True).
compare_rocs(
    dfs=[df16] * 5,
    names=["raw:DeepCSV - BvL", "raw:DeepCSV - BvC", "raw:DeepJet - BvL", "raw:DeepJet - BvC", "raw:CSVv2"],
    tagger_names=['dBvL', 'dBvC', 'dfBvL', 'dfBvC', 'Jet_btagCSVV2'],
    use_tagger=[True] * 5,
    ignore=[False] * 5,
    colors=['darkblue', 'darkblue', 'red', 'red', 'green'],
    pt=[20, 2000],
    supp=True,
    styles=['-', '--', '-', '--', '-'],
    sigs=[["b"], ["b"], ["b"], ["b"], ["b"]],
    bkgs=[["udsg"], ["c"], ["udsg"], ["c"], ["c"]],
    plotname="test",
    year="2016",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment