Last active
June 30, 2019 10:54
-
-
Save gaphex/92d292ec92d866019f36df02c02f83a7 to your computer and use it in GitHub Desktop.
Optimize and serialize BERT graph for inference.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Optimize a pre-trained BERT checkpoint for inference with bert-as-service
# and store the serialized graph at GRAPH_DIR/GRAPH_OUT.
import os

import tensorflow as tf
from bert_serving.server.graph import optimize_graph
from bert_serving.server.helper import get_args_parser

# Colab form parameters (the trailing `#@param` markers are read by Colab).
MODEL_DIR = '/content/wwm_uncased_L-24_H-1024_A-16/' #@param {type:"string"}
GRAPH_DIR = '/content/graph/' #@param {type:"string"}
GRAPH_OUT = 'extractor.pbtxt' #@param {type:"string"}
# Numeric value, so the form field is declared as a number rather than a string.
GPU_MFRAC = 0.2 #@param {type:"number"}
POOL_STRAT = 'REDUCE_MEAN' #@param {type:"string"}
POOL_LAYER = "-2" #@param {type:"string"}
SEQ_LEN = "64" #@param {type:"string"}

# MakeDirs (unlike MkDir) also creates missing parent directories and
# succeeds when the directory already exists, so re-running is safe.
tf.gfile.MakeDirs(GRAPH_DIR)

# bert-as-service is configured through its command-line parser, so the
# settings are passed as an argv-style list; every value must be a string.
parser = get_args_parser()
carg = parser.parse_args(args=[
    '-model_dir', MODEL_DIR,
    '-graph_tmp_dir', GRAPH_DIR,
    '-max_seq_len', str(SEQ_LEN),
    '-pooling_layer', str(POOL_LAYER),
    '-pooling_strategy', POOL_STRAT,
    '-gpu_memory_fraction', str(GPU_MFRAC),
])

# optimize_graph returns the path of the file it wrote plus a second value
# that is kept as `config` (presumably the model configuration -- confirm
# against the bert-as-service sources).
tmpfi_name, config = optimize_graph(carg)

# Move the generated file to its final, predictable name, replacing any
# graph left over from a previous run.
graph_fout = os.path.join(GRAPH_DIR, GRAPH_OUT)
tf.gfile.Rename(tmpfi_name, graph_fout, overwrite=True)
print("Serialized graph to {}".format(graph_fout))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment