This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if BUCKET_NAME: | |
!gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Colab form parameters for storage and the input data pipeline. The
# `#@param` annotations are read by Colab to render editable form fields,
# so keep them as the last thing on their lines.

# Name of the Google Cloud Storage bucket (used as gs://<name>); an empty
# string disables the upload step.
# NOTE(review): "resourses" looks like a misspelling, but the value must match
# the real bucket name — verify before changing it.
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
# Local directory for the model, local directory for the generated
# pre-training data, and the vocabulary file name.
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}

# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
# Presumably the maximum number of masked-LM predictions per sequence — confirm.
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
# Presumably the fraction of tokens selected for masking — confirm.
MASKED_LM_PROB = 0.15 #@param
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the BERT pre-training model function from the config and the
# hyper-parameters defined in earlier cells.
model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=INIT_CHECKPOINT,
    learning_rate=LEARNING_RATE,
    num_train_steps=TRAIN_STEPS,
    # NOTE(review): warm-up length is hard-coded to 10 steps instead of being
    # derived from TRAIN_STEPS — confirm this is intended.
    num_warmup_steps=10,
    use_tpu=USE_TPU,
    # NOTE(review): one-hot embedding lookup is forced on independently of
    # USE_TPU — confirm this is intended when running without a TPU.
    use_one_hot_embeddings=True)

# Resolve the TPU worker at TPU_ADDRESS (tf.contrib is a TensorFlow 1.x API).
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run pre-training until the estimator's global step reaches TRAIN_STEPS
# (steps already recorded in a restored checkpoint count toward the limit).
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import tensorflow as tf | |
from bert_serving.server.graph import optimize_graph | |
from bert_serving.server.helper import get_args_parser | |
# Colab form parameters: the BERT checkpoint directory and the directory and
# file name used for the exported graph.
MODEL_DIR = '/content/wwm_uncased_L-24_H-1024_A-16/' #@param {type:"string"}
GRAPH_DIR = '/content/graph/' #@param {type:"string"}
# NOTE(review): if GRAPH_PATH is built from this name, the graph is read back
# with GraphDef.ParseFromString (binary protobuf), so ".pb" would be the
# conventional extension rather than ".pbtxt" (text format) — confirm.
GRAPH_OUT = 'extractor.pbtxt' #@param {type:"string"}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def model_fn(features, mode): | |
with tf.gfile.GFile(GRAPH_PATH, 'rb') as f: | |
graph_def = tf.GraphDef() | |
graph_def.ParseFromString(f.read()) | |
output = tf.import_graph_def(graph_def, | |
input_map={k + ':0': features[k] | |
for k in INPUT_NAMES}, | |
return_elements=['final_encodes:0']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Names of the graph input tensors; each is fed from the feature dict entry
# with the same key (the graph tensor is "<name>:0").
INPUT_NAMES = ['input_ids', 'input_mask', 'input_type_ids']
# Tokenizer built from the vocabulary file at VOCAB_PATH.
bert_tokenizer = FullTokenizer(VOCAB_PATH)
def build_feed_dict(texts): | |
text_features = list(convert_lst_to_features( | |
texts, SEQ_LEN, SEQ_LEN, | |
bert_tokenizer, log, False, False)) | |
target_shape = (len(texts), -1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_input_fn(container): | |
def gen(): | |
while True: | |
try: | |
yield build_feed_dict(container.get()) | |
except StopIteration: | |
yield build_feed_dict(container.get()) | |
def input_fn(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` of length ``n``.

    The last slice is shorter when ``len(iterable)`` is not a multiple of
    ``n``. Each slice has the same type as the input (list -> list,
    str -> str, ...).

    Args:
        iterable: A sized, sliceable sequence (list, tuple, str, ...).
            Despite the parameter name, plain iterators are not supported.
        n: Slice length; must be at least 1.

    Yields:
        Successive slices ``iterable[i:i + n]``.

    Raises:
        ValueError: On first iteration, if ``n`` is less than 1.
    """
    if n < 1:
        # Without this check, range() raises an opaque error for n == 0 and
        # silently yields nothing for negative n.
        raise ValueError(f"batch size must be >= 1, got {n}")
    size = len(iterable)
    for start in range(0, size, n):
        # Slicing clamps at the end of the sequence, so no min() is needed.
        yield iterable[start:start + n]
def build_vectorizer(_estimator, _input_fn_builder, batch_size=128): | |
container = DataContainer() | |
predict_fn = _estimator.predict(_input_fn_builder(container), yield_single_examples=False) | |
def vectorize(text, verbose=False): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Vector width; presumably must equal the encoder's output size (the
# checkpoint directory used in this file contains "H-1024") — confirm.
dim = 1024
graph = tf.Graph()
# Passing graph= makes InteractiveSession install itself as the default
# session and `graph` as the default graph (TF 1.x), so the placeholders
# below are created in `graph` — confirm for the TensorFlow version in use.
sess = tf.InteractiveSession(graph=graph)
# Q: a single dim-sized vector; S: a batch of any number of dim-sized vectors.
Q = tf.placeholder("float", [dim])
S = tf.placeholder("float", [None, dim])