Last active
January 24, 2020 12:55
-
-
Save riga/fc2fe7352b0d1eb1b307a197fea4cd4c to your computer and use it in GitHub Desktop.
CMSSW TensorFlow 1.13.1 evaluation test for the DeepFlavour (DeepJet) tagger
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8

"""
Evaluation test script for the DeepJet (DeepFlavour) tagger.

The input tensor names and dimensions are taken from the CMSSW producer at
https://github.com/cms-sw/cmssw/blob/02d4198c0b6615287fd88e9a8ff650aea994412e/RecoBTag/TensorFlow/plugins/DeepFlavourTFJetTagsProducer.cc
"""

import os
import sys
import time

import numpy as np
import tensorflow as tf


# Under TensorFlow 2.x, switch to the v1 compatibility API so the graph/session
# calls used below (tf.Graph, tf.GraphDef, tf.gfile, tf.Session) keep working.
if tf.__version__[:2] == "2.":
    tf = tf.compat.v1
def dummy_data(shape, value=0.1):
    """
    Create a constant float32 placeholder array.

    :param shape: shape of the array to create (e.g. ``(batch_size, 15)``).
    :param value: value written to every entry, 0.1 by default.
    :return: numpy array of the given shape, dtype float32, filled with value.
    """
    # np.full pins the dtype to float32 regardless of the type of value.
    # Multiplying a float32 array by a numpy float64 scalar would promote the
    # result to float64 under NumPy 2. It also avoids a temporary array.
    return np.full(shape, value, dtype=np.float32)
def load_graph(graph_file):
    """
    Load a serialized constant graph into a new, dedicated tf.Graph.

    :param graph_file: path to the serialized GraphDef (opened in binary mode).
    :return: a fresh tf.Graph holding the imported nodes, without any name
        prefix (``name=""``), so tensors keep their original names such as
        ``"input_1:0"``.
    """
    # Read the raw protobuf bytes.
    with tf.gfile.GFile(graph_file, "rb") as stream:
        serialized = stream.read()

    # Decode them into a GraphDef message.
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)

    # Import the definition into its own graph rather than the global default.
    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(graph_def, name="")
    return loaded
# Report the software versions in use (prints None when the CMSSW_VERSION
# environment variable is not set).
print("CMSSW: {}".format(os.getenv("CMSSW_VERSION")))
print("TF : {}\n".format(tf.__version__))

# The graph path may be given as the first command line argument; otherwise
# the DeepFlavourV03 constant graph from the CMSSW_10_6_0_patch1 release on
# CVMFS is used.
graph_file = sys.argv[1] if len(sys.argv) > 1 else "/cvmfs/cms.cern.ch/slc7_amd64_gcc700/cms/cmssw-patch/CMSSW_10_6_0_patch1/external/slc7_amd64_gcc700/data/RecoBTag/Combined/data/DeepFlavourV03_10X_training/constant_graph.pb"  # noqa

# Load the graph, then open a session bound to it. `sess` is used as a
# module-level global by the evaluation function.
print("loading graph from {}".format(graph_file))
graph = load_graph(graph_file)
print("done\n")

print("starting session")
sess = tf.Session(graph=graph)
print("done\n")
def run(batch_size, silent=False):
    """
    Evaluate the DeepJet graph once on constant dummy inputs.

    Every input is filled with dummy_data's default value (0.1; the inputs are
    constant, not random). The Keras learning phase is fed as False, i.e.
    inference mode. Input names and shapes follow the CMSSW
    DeepFlavourTFJetTagsProducer.

    :param batch_size: size of the leading (batch) dimension of every input.
    :param silent: when True, nothing is printed.
    :return: wall-clock duration of the single ``sess.run`` call, in seconds.
    """
    verbose = not silent

    # (tensor name, per-example shape) for each network input; the batch
    # dimension is prepended below.
    input_shapes = [
        ("input_1:0", (15,)),
        ("input_2:0", (25, 16)),
        ("input_3:0", (25, 6)),
        ("input_4:0", (4, 12)),
        ("input_5:0", (1,)),
    ]
    feed_dict = {
        name: dummy_data((batch_size,) + shape)
        for name, shape in input_shapes
    }
    feed_dict["cpf_input_batchnorm/keras_learning_phase:0"] = False

    if verbose:
        print("evaluating the session, batch size {}".format(batch_size))

    # Time only the session call itself.
    start = time.time()
    outputs = sess.run(["ID_pred/Softmax:0"], feed_dict=feed_dict)
    elapsed = time.time() - start

    if verbose:
        print("done, took {:.2f} ms\n".format(elapsed * 1e3))
        print("DeepJet outputs:")
        print(outputs)

    return elapsed
# Evaluate one batch of size 1, printing the timing and the network outputs.
run(batch_size=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.