Skip to content

Instantly share code, notes, and snippets.

@riga
Last active January 24, 2020 12:55
Show Gist options
  • Save riga/fc2fe7352b0d1eb1b307a197fea4cd4c to your computer and use it in GitHub Desktop.
Save riga/fc2fe7352b0d1eb1b307a197fea4cd4c to your computer and use it in GitHub Desktop.
CMSSW TensorFlow 1.13.1 evaluation test for the DeepFlavor tagger
# coding: utf-8
"""
Test evaluation script for the DeepJet tagger.
Input dimensions and names from
https://github.com/cms-sw/cmssw/blob/02d4198c0b6615287fd88e9a8ff650aea994412e/RecoBTag/TensorFlow/plugins/DeepFlavourTFJetTagsProducer.cc
"""
import os
import sys
import time
import numpy as np
import tensorflow as tf
# When TensorFlow 2 is installed, rebind ``tf`` to the TF1 compatibility API.
# The rest of this script uses TF1-style names (tf.GraphDef, tf.gfile,
# tf.Session), which TF2 only exposes under tf.compat.v1.
if tf.__version__.startswith("2."):
    tf = tf.compat.v1
def dummy_data(shape, value=0.1, dtype=np.float32):
    """Return an array of the given shape filled with a constant value.

    Used by ``run`` to build placeholder inputs for the graph.

    Args:
        shape: int or tuple of ints, the shape of the returned array.
        value: fill value for every element (default 0.1).
        dtype: numpy dtype of the returned array (default float32). The fill
            value is always cast to this dtype, so passing e.g. a numpy
            float64 scalar as ``value`` cannot silently upcast the result.

    Returns:
        numpy.ndarray of ``shape`` and ``dtype`` filled with ``value``.
    """
    return np.full(shape, value, dtype=dtype)
def load_graph(graph_file):
    """Load a constant (frozen) TensorFlow graph from a protobuf file.

    Args:
        graph_file: path to a binary-serialized ``GraphDef`` file
            (e.g. a ``constant_graph.pb``).

    Returns:
        A new ``tf.Graph`` containing the imported nodes. Tensor names such
        as ``"input_1:0"`` are addressable exactly as stored in the file.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        # the file is a serialized protobuf, so it must be read as bytes
        with tf.gfile.GFile(graph_file, "rb") as f:
            graph_def.ParseFromString(f.read())
        # name="" stops TF from prefixing every imported node with "import/"
        tf.import_graph_def(graph_def, name="")
    return graph
# Graph used when no path is given on the command line: the DeepFlavourV03
# training file shipped with CMSSW_10_6_0_patch1 on CVMFS.
DEFAULT_GRAPH_FILE = "/cvmfs/cms.cern.ch/slc7_amd64_gcc700/cms/cmssw-patch/CMSSW_10_6_0_patch1/external/slc7_amd64_gcc700/data/RecoBTag/Combined/data/DeepFlavourV03_10X_training/constant_graph.pb"  # noqa

# print the CMSSW and TF versions (CMSSW prints "None" if CMSSW_VERSION is unset)
print("CMSSW: {}".format(os.getenv("CMSSW_VERSION")))
print("TF : {}\n".format(tf.__version__))

# use the graph file from argv or default to the released file
graph_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_GRAPH_FILE

# load the graph
print("loading graph from {}".format(graph_file))
graph = load_graph(graph_file)
print("done\n")

# start the session; it is read as a module-level global by run()
print("starting session")
sess = tf.Session(graph=graph)
print("done\n")
def run(batch_size, silent=False):
    """Evaluate the DeepJet graph once on constant dummy inputs.

    Uses the module-level ``sess`` session. Input tensor names and shapes
    follow the DeepFlavourTFJetTagsProducer referenced in the module
    docstring.

    Args:
        batch_size: leading dimension of every input tensor.
        silent: if True, print nothing.

    Returns:
        Duration of the ``sess.run`` call, in seconds.
    """
    # Constant (not random) dummy inputs. The shapes are per-batch-entry
    # feature dimensions; their physics meaning (presumably global / charged /
    # neutral / vertex / pt features) should be confirmed against the producer.
    feed_dict = {
        "input_1:0": dummy_data((batch_size, 15)),
        "input_2:0": dummy_data((batch_size, 25, 16)),
        "input_3:0": dummy_data((batch_size, 25, 6)),
        "input_4:0": dummy_data((batch_size, 4, 12)),
        "input_5:0": dummy_data((batch_size, 1)),
        # Keras learning phase: False selects inference (not training) behavior
        "cpf_input_batchnorm/keras_learning_phase:0": False,
    }

    if not silent:
        print("evaluating the session, batch size {}".format(batch_size))

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the appropriate clock for measuring a short evaluation
    t0 = time.perf_counter()
    outputs = sess.run(["ID_pred/Softmax:0"], feed_dict=feed_dict)
    diff = time.perf_counter() - t0

    if not silent:
        print("done, took {:.2f} ms\n".format(diff * 1e3))
        print("DeepJet outputs:")
        print(outputs)

    return diff
# test a single batch, then release the session's resources even on failure
try:
    run(1)
finally:
    sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment