This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| if BUCKET_NAME: | |
| !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Storage / artifact locations. The `#@param` markers are notebook form
# annotations and are kept verbatim.
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| model_fn = model_fn_builder( | |
| bert_config=bert_config, | |
| init_checkpoint=INIT_CHECKPOINT, | |
| learning_rate=LEARNING_RATE, | |
| num_train_steps=TRAIN_STEPS, | |
| num_warmup_steps=10, | |
| use_tpu=USE_TPU, | |
| use_one_hot_embeddings=True) | |
| tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import tensorflow as tf | |
| from bert_serving.server.graph import optimize_graph | |
| from bert_serving.server.helper import get_args_parser | |
| MODEL_DIR = '/content/wwm_uncased_L-24_H-1024_A-16/' #@param {type:"string"} | |
| GRAPH_DIR = '/content/graph/' #@param {type:"string"} | |
| GRAPH_OUT = 'extractor.pbtxt' #@param {type:"string"} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def model_fn(features, mode): | |
| with tf.gfile.GFile(GRAPH_PATH, 'rb') as f: | |
| graph_def = tf.GraphDef() | |
| graph_def.ParseFromString(f.read()) | |
| output = tf.import_graph_def(graph_def, | |
| input_map={k + ':0': features[k] | |
| for k in INPUT_NAMES}, | |
| return_elements=['final_encodes:0']) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| INPUT_NAMES = ['input_ids', 'input_mask', 'input_type_ids'] | |
| bert_tokenizer = FullTokenizer(VOCAB_PATH) | |
| def build_feed_dict(texts): | |
| text_features = list(convert_lst_to_features( | |
| texts, SEQ_LEN, SEQ_LEN, | |
| bert_tokenizer, log, False, False)) | |
| target_shape = (len(texts), -1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def build_input_fn(container): | |
| def gen(): | |
| while True: | |
| try: | |
| yield build_feed_dict(container.get()) | |
| except StopIteration: | |
| yield build_feed_dict(container.get()) | |
| def input_fn(): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items.

    ``iterable`` must support ``len()`` and slicing (e.g. a list, tuple or
    string). The final slice is shorter when the length is not a multiple
    of ``n``; an empty input yields nothing.
    """
    total = len(iterable)
    for start in range(0, total, n):
        # Slicing clamps at the end of the sequence, so no explicit bound
        # check is needed for the last, possibly shorter, chunk.
        yield iterable[start:start + n]
| def build_vectorizer(_estimator, _input_fn_builder, batch_size=128): | |
| container = DataContainer() | |
| predict_fn = _estimator.predict(_input_fn_builder(container), yield_single_examples=False) | |
| def vectorize(text, verbose=False): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| dim = 1024 | |
| graph = tf.Graph() | |
| sess = tf.InteractiveSession(graph=graph) | |
| Q = tf.placeholder("float", [dim]) | |
| S = tf.placeholder("float", [None, dim]) |