Denis (@gaphex)

if BUCKET_NAME:
  !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME
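The copy above assumes the Colab session is already authorized against Google Cloud Storage and that the bucket exists. A minimal sketch of that prerequisite step, using the standard Colab auth helper and gsutil (the bucket name is reused from the parameters below):

from google.colab import auth

auth.authenticate_user()        # grant the notebook access to GCS
!gsutil mb gs://$BUCKET_NAME    # create the bucket if it does not exist yet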
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}
# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
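These parameters mirror the flags of create_pretraining_data.py in the google-research/bert repository. A hedged sketch of how they could be wired into that script; the input and output file names are placeholders, not values from the original gist:

!python3 bert/create_pretraining_data.py \
  --input_file=dataset.txt \
  --output_file=$PRETRAINING_DIR/tf_examples.tfrecord \
  --vocab_file=$VOC_FNAME \
  --do_lower_case=True \
  --max_seq_length=$MAX_SEQ_LENGTH \
  --max_predictions_per_seq=$MAX_PREDICTIONS \
  --masked_lm_prob=$MASKED_LM_PROB \
  --dupe_factor=5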
model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=INIT_CHECKPOINT,
    learning_rate=LEARNING_RATE,
    num_train_steps=TRAIN_STEPS,
    num_warmup_steps=10,
    use_tpu=USE_TPU,
    use_one_hot_embeddings=True)
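The names passed to model_fn_builder are not defined in this preview. A hedged sketch of where they typically come from, assuming the google-research/bert repository is cloned into ./bert and the config file sits in MODEL_DIR; all concrete values are illustrative:

import os
import sys
sys.path.append("bert")                        # assumption: repo cloned locally

import modeling
from run_pretraining import model_fn_builder   # defines the builder used above

bert_config = modeling.BertConfig.from_json_file(
    os.path.join(MODEL_DIR, "bert_config.json"))
INIT_CHECKPOINT = None      # assumption: pre-training from scratch, no warm start
LEARNING_RATE = 2e-5        # illustrative
TRAIN_STEPS = 100000        # illustrative
USE_TPU = True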
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
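Both estimator and train_input_fn are assumed to exist at this point. A hedged sketch of how they could be assembled with the TF 1.x TPU APIs, reusing input_fn_builder from run_pretraining.py; directory layout and hyperparameters are illustrative:

from run_pretraining import input_fn_builder

train_input_fn = input_fn_builder(
    input_files=tf.gfile.Glob("gs://{}/{}/*.tfrecord".format(BUCKET_NAME, PRETRAINING_DIR)),
    max_seq_length=MAX_SEQ_LENGTH,
    max_predictions_per_seq=MAX_PREDICTIONS,
    is_training=True)

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir="gs://{}/{}".format(BUCKET_NAME, MODEL_DIR),
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1000, num_shards=8))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    use_tpu=USE_TPU)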
@gaphex
gaphex / optimize_inference_graph.py
Last active June 30, 2019 10:54
Optimize and serialize BERT graph for inference.
import os
import tensorflow as tf
from bert_serving.server.graph import optimize_graph
from bert_serving.server.helper import get_args_parser
MODEL_DIR = '/content/wwm_uncased_L-24_H-1024_A-16/' #@param {type:"string"}
GRAPH_DIR = '/content/graph/' #@param {type:"string"}
GRAPH_OUT = 'extractor.pbtxt' #@param {type:"string"}
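These three paths feed the graph-optimization helper imported above. A hedged sketch of the call; the pooling settings are illustrative, and the exact argument list and return value of optimize_graph differ between bert-serving-server releases:

import shutil

POOL_STRAT = 'REDUCE_MEAN'   # pooling strategy for the extractor (illustrative)
POOL_LAYER = '-2'            # which encoder layer to pool (illustrative)
SEQ_LEN = '64'               # maximum sequence length (illustrative)

tf.gfile.MakeDirs(GRAPH_DIR)

carg = get_args_parser().parse_args(['-model_dir', MODEL_DIR,
                                     '-max_seq_len', SEQ_LEN,
                                     '-pooling_layer', POOL_LAYER,
                                     '-pooling_strategy', POOL_STRAT])

# assumption: in this bert-serving-server version optimize_graph returns the
# path of a temporary serialized graph plus the BERT config; adjust if yours differs
tmp_graph, _ = optimize_graph(carg)
shutil.move(tmp_graph, os.path.join(GRAPH_DIR, GRAPH_OUT))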
@gaphex
gaphex / build_estimator.py
Last active June 13, 2019 13:28
Building a tf.Estimator from serialized GraphDef
def model_fn(features, mode):
    with tf.gfile.GFile(GRAPH_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    output = tf.import_graph_def(graph_def,
                                 input_map={k + ':0': features[k]
                                            for k in INPUT_NAMES},
                                 return_elements=['final_encodes:0'])
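    # hedged continuation (not shown in this preview): expose the imported
    # tensor as the prediction output of a serving-only Estimator
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions={'output': output[0]})

# assumption: the graph is consumed for inference only, so a plain Estimator
# (rather than a TPUEstimator) is sufficient here
estimator = tf.estimator.Estimator(model_fn=model_fn)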
@gaphex
gaphex / buid_feed_dict.py
Last active June 9, 2019 11:08
List of strings to feed dict converter
INPUT_NAMES = ['input_ids', 'input_mask', 'input_type_ids']
bert_tokenizer = FullTokenizer(VOCAB_PATH)
def build_feed_dict(texts):
    text_features = list(convert_lst_to_features(
        texts, SEQ_LEN, SEQ_LEN,
        bert_tokenizer, log, False, False))

    target_shape = (len(texts), -1)
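    # hedged continuation (not shown in this preview): stack each feature of
    # the tokenized batch into an int32 matrix keyed by its input name
    # (assumes numpy is imported as np)
    feed_dict = {}
    for iname in INPUT_NAMES:
        features_i = np.array([getattr(f, iname) for f in text_features])
        feed_dict[iname] = features_i.reshape(target_shape).astype('int32')
    return feed_dict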
@gaphex
gaphex / input_fn_with_generator.py
Last active June 23, 2019 11:57
input fn with generator
def build_input_fn(container):
    def gen():
        while True:
            try:
                yield build_feed_dict(container.get())
            except StopIteration:
                yield build_feed_dict(container.get())

    def input_fn():
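        # hedged continuation (not shown in this preview): wrap the generator
        # in a tf.data.Dataset so the Estimator can consume dynamically fed batches
        return tf.data.Dataset.from_generator(
            gen,
            output_types={iname: tf.int32 for iname in INPUT_NAMES},
            output_shapes={iname: (None, None) for iname in INPUT_NAMES})
    return input_fn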
@gaphex
gaphex / run_inference.py
Last active June 23, 2019 11:58
Extracting features from text
def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

def build_vectorizer(_estimator, _input_fn_builder, batch_size=128):
    container = DataContainer()
    predict_fn = _estimator.predict(_input_fn_builder(container),
                                    yield_single_examples=False)

    def vectorize(text, verbose=False):
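        # hedged continuation (not shown in this preview): push each batch into
        # the shared container, pull the matching prediction, and stack results
        # (assumes numpy is imported as np, DataContainer exposes set(), and the
        # prediction dict uses the 'output' key from the Estimator sketch above)
        vectors = []
        for text_batch in batch(text, batch_size):
            container.set(text_batch)
            vectors.append(next(predict_fn)['output'])
        return np.vstack(vectors)
    return vectorize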
@gaphex
gaphex / retriever_placeholders.py
Created June 13, 2019 15:04
create placeholders for KNN
dim = 1024
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)
Q = tf.placeholder("float", [dim])
S = tf.placeholder("float", [None, dim])
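With the query placeholder Q and the stored-vector matrix S in place, a hedged sketch of the retrieval op itself: cosine similarity between the query and every stored embedding, followed by a top-k lookup (the value of k is illustrative):

Q_norm = tf.nn.l2_normalize(Q, axis=0)                       # unit-length query
S_norm = tf.nn.l2_normalize(S, axis=1)                       # unit-length corpus rows
cosine_sim = tf.reduce_sum(tf.multiply(S_norm, Q_norm), axis=1)
top_k_scores, top_k_indices = tf.nn.top_k(cosine_sim, k=10)  # nearest neighbours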