Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active June 30, 2019 10:54
Show Gist options
  • Select an option

  • Save gaphex/92d292ec92d866019f36df02c02f83a7 to your computer and use it in GitHub Desktop.

Select an option

Save gaphex/92d292ec92d866019f36df02c02f83a7 to your computer and use it in GitHub Desktop.
Optimize and serialize BERT graph for inference.
import os
import tensorflow as tf
from bert_serving.server.graph import optimize_graph
from bert_serving.server.helper import get_args_parser
# Colab form parameters for the BERT graph export. Colab reads the `#@param`
# comments to render an editable form, so keep their syntax intact.

# Directory holding the pre-trained BERT checkpoint (passed as -model_dir).
MODEL_DIR = '/content/wwm_uncased_L-24_H-1024_A-16/' #@param {type:"string"}
# Directory for optimize_graph's temporary file (passed as -graph_tmp_dir);
# the exported graph is also written here.
GRAPH_DIR = '/content/graph/' #@param {type:"string"}
# File name of the exported graph inside GRAPH_DIR.
# NOTE(review): optimize_graph appears to write a binary-serialized GraphDef,
# so a `.pb` extension may describe the content better than `.pbtxt`; confirm
# before renaming, since downstream loaders may depend on this exact name.
GRAPH_OUT = 'extractor.pbtxt' #@param {type:"string"}
# Fraction of GPU memory to allow (passed as -gpu_memory_fraction). Declared
# as a number so the form type matches the float default.
GPU_MFRAC = 0.2 #@param {type:"number"}
# Pooling strategy (passed as -pooling_strategy).
POOL_STRAT = 'REDUCE_MEAN' #@param {type:"string"}
# Encoder layer to pool over (passed as -pooling_layer); kept as a string
# because it is forwarded verbatim as a command-line argument.
POOL_LAYER = "-2" #@param {type:"string"}
# Maximum sequence length (passed as -max_seq_len); string for the same reason.
SEQ_LEN = "64" #@param {type:"string"}
# Create the output directory and any missing parents. Unlike MkDir, MakeDirs
# succeeds when the directory already exists, so re-running this cell is safe.
tf.gfile.MakeDirs(GRAPH_DIR)

# Build the bert-as-service argument namespace the same way its CLI would.
# Values are stringified because argparse expects command-line strings.
parser = get_args_parser()
carg = parser.parse_args(args=[
    '-model_dir', MODEL_DIR,
    '-graph_tmp_dir', GRAPH_DIR,
    '-max_seq_len', str(SEQ_LEN),
    '-pooling_layer', str(POOL_LAYER),
    '-pooling_strategy', POOL_STRAT,
    '-gpu_memory_fraction', str(GPU_MFRAC),
])

# optimize_graph returns (temp_file_path, bert_config) on success. On failure
# it logs the exception instead of raising and returns None, which would
# otherwise surface here as an opaque "cannot unpack" TypeError.
optimized = optimize_graph(carg)
if optimized is None:
    raise RuntimeError(
        'optimize_graph failed for model_dir={!r}; '
        'see the logged error above.'.format(MODEL_DIR))
tmpfi_name, config = optimized

# Move the temporary graph file to its final name. overwrite=True replaces
# the output of a previous run. The temp file is presumably created inside
# -graph_tmp_dir (GRAPH_DIR), which would make this a same-directory rename.
graph_fout = os.path.join(GRAPH_DIR, GRAPH_OUT)
tf.gfile.Rename(tmpfi_name, graph_fout, overwrite=True)
print("Serialized graph to {}".format(graph_fout))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment