Skip to content

Instantly share code, notes, and snippets.

Last active January 24, 2020 12:55
Show Gist options
  • Save riga/fc2fe7352b0d1eb1b307a197fea4cd4c to your computer and use it in GitHub Desktop.
Save riga/fc2fe7352b0d1eb1b307a197fea4cd4c to your computer and use it in GitHub Desktop.
CMSSW TensorFlow 1.13.1 evaluation test for the DeepJet (DeepFlavour) tagger
# coding: utf-8
"""
Test evaluation script for the DeepJet tagger.
Input dimensions and names from: (reference link lost in this copy - TODO restore it)
"""
import os
import sys
import time
import numpy as np
import tensorflow as tf
# TensorFlow 2 ships the 1.x graph/session API under tf.compat.v1; rebind
# `tf` to it so the GraphDef / gfile / Session calls below work on both
# major versions
if tf.__version__[:2] == "2.":
    tf = tf.compat.v1
def dummy_data(shape, value=0.1):
    """Build a constant dummy input array.

    Args:
        shape: shape of the array to create (e.g. ``(batch_size, 15)``).
        value: the value every element is set to.

    Returns:
        A ``float32`` numpy array of the requested shape, filled with ``value``.
    """
    ones = np.ones(shape, dtype=np.float32)
    return ones * value
# helper to load a constant graph
def load_graph(graph_file):
graph = tf.Graph()
with graph.as_default():
graph_def = tf.GraphDef()
with tf.gfile.GFile(graph_file, "rb") as f:
tf.import_graph_def(graph_def, name="")
return graph
# print the CMSSW and TF versions
print("CMSSW: {}".format(os.getenv("CMSSW_VERSION")))
print("TF : {}\n".format(tf.__version__))
# use the graph file from argv or default to the released file
if len(sys.argv) > 1:
graph_file = sys.argv[1]
graph_file = "/cvmfs/" # noqa
# load the graph
print("loading graph from {}".format(graph_file))
graph = load_graph(graph_file)
# start the session
print("starting session")
sess = tf.Session(graph=graph)
def run(batch_size, silent=False):
# build random input tensors
feed_dict = {
"input_1:0": dummy_data((batch_size, 15)),
"input_2:0": dummy_data((batch_size, 25, 16)),
"input_3:0": dummy_data((batch_size, 25, 6)),
"input_4:0": dummy_data((batch_size, 4, 12)),
"input_5:0": dummy_data((batch_size, 1)),
"cpf_input_batchnorm/keras_learning_phase:0": False,
# run
if not silent:
print("evaluating the session, batch size {}".format(batch_size))
t0 = time.time()
outputs =["ID_pred/Softmax:0"], feed_dict=feed_dict)
diff = time.time() - t0
if not silent:
print("done, took {:.2f} ms\n".format(diff * 1e3))
# print outputs
if not silent:
print("DeepJet outputs:")
return diff
# test a single batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment