import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

MODEL_DIR = "uncased_L-12_H-768_A-12"
config_path = "/content/{}/bert_config.json".format(MODEL_DIR)
vocab_path = "/content/{}/vocab.txt".format(MODEL_DIR)

tags_and_args = []
for is_training in (True, False):
    tags = set()
    if is_training:
        tags.add("train")
    # completion sketch: pair each tag set with matching module_fn kwargs
    tags_and_args.append((tags, dict(is_training=is_training)))
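A hedged sketch of how these tags are typically consumed: hub.create_module_spec builds a module spec covering both the training and inference graph variants, and export writes it out together with the checkpoint weights. Here build_module_fn is a hypothetical factory wrapping the stock BERT graph-building code; the checkpoint path mirrors the layout above.

# sketch: export the BERT checkpoint as a TF Hub module
checkpoint_path = "/content/{}/bert_model.ckpt".format(MODEL_DIR)
module_fn = build_module_fn(config_path, vocab_path)  # hypothetical helper
spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args)
spec.export("bert-module", checkpoint_path=checkpoint_path)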
import re

def read_examples(str_list):
    """Read a list of `InputExample`s from a list of strings."""
    # assumes convert_to_unicode and InputExample from the google-research/bert utilities
    unique_id = 0
    for s in str_list:
        line = convert_to_unicode(s)
        if not line:
            continue
        line = line.strip()
        text_a = None
        text_b = None
        # completion sketch, following the stock extract_features.py:
        # "a ||| b" lines become sentence pairs, otherwise the line is text_a
        m = re.match(r"^(.*) \|\|\| (.*)$", line)
        if m is None:
            text_a = line
        else:
            text_a, text_b = m.group(1), m.group(2)
        yield InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)
        unique_id += 1
def features_to_arrays(features):
    """Convert a list of InputFeatures to np.arrays."""
    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []
    for feature in features:
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_segment_ids.append(feature.segment_ids)
    # completion sketch: pack the per-example lists into int32 arrays
    return (np.array(all_input_ids, dtype="int32"),
            np.array(all_input_mask, dtype="int32"),
            np.array(all_segment_ids, dtype="int32"))
def build_preprocessor(voc_path, seq_len, lower=True):
    """
    Build a text preprocessing pipeline for BERT.
    Returns a function which converts a list of strings to a list
    of three np.arrays: [input_ids, input_mask, segment_ids].
    """
    tokenizer = FullTokenizer(vocab_file=voc_path, do_lower_case=lower)

    def strings_to_arrays(sents):
        # completion sketch: tokenize, pad/truncate to seq_len, pack as arrays
        sents = np.atleast_1d(sents).reshape((-1,))
        examples = list(read_examples(sents))
        features = convert_examples_to_features(examples, seq_len, tokenizer)
        return features_to_arrays(features)

    return strings_to_arrays
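A minimal usage sketch of the resulting callable, assuming the vocab file from the checkpoint above; each string comes back as fixed-length id, mask, and segment arrays:

preprocess = build_preprocessor(vocab_path, seq_len=64)
input_ids, input_mask, segment_ids = preprocess(["Hello, BERT!"])
print(input_ids.shape)  # (1, 64)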
class BertLayer(tf.keras.layers.Layer):
    def __init__(self, bert_path, seq_len=64, n_tune_layers=3,
                 pooling="cls", do_preprocessing=True, verbose=False,
                 tune_embeddings=False, **kwargs):
        self.n_tune_layers = n_tune_layers
        self.tune_embeddings = tune_embeddings
        self.do_preprocessing = do_preprocessing
        self.seq_len = seq_len
        self.trainable = True
        # completion sketch: remaining attributes inferred from the methods below
        self.bert_path = bert_path
        self.pooling = pooling
        self.verbose = verbose
        super(BertLayer, self).__init__(**kwargs)
    def build(self, input_shape):
        self.bert = hub.Module(self.bert_path, trainable=self.trainable, name=f"{self.name}_module")
        trainable_layers = []
        if self.tune_embeddings:
            trainable_layers.append("embeddings")
        if self.pooling == "cls":
            trainable_layers.append("pooler")
        # completion sketch: unfreeze the top n_tune_layers encoder blocks (BERT-base has 12)
        for i in range(self.n_tune_layers):
            trainable_layers.append("encoder/layer_{}/".format(11 - i))
        for var in self.bert.variables:
            is_tunable = any(l in var.name for l in trainable_layers)
            (self._trainable_weights if is_tunable else self._non_trainable_weights).append(var)
        super(BertLayer, self).build(input_shape)
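To see which variables the name matching above actually unfreezes, a quick hypothetical check (calling build directly just to trigger the variable partitioning):

layer = BertLayer(bert_path="./bert-module/", n_tune_layers=3)
layer.build(input_shape=(None, 64))
print([w.name for w in layer._trainable_weights][:5])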
    def build_preprocessor(self):
        sess = tf.keras.backend.get_session()
        tokenization_info = self.bert(signature="tokenization_info", as_dict=True)
        vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                              tokenization_info["do_lower_case"]])
        self.preprocessor = build_preprocessor(vocab_file, self.seq_len, do_lower_case)
    def initialize_module(self):
        sess = tf.keras.backend.get_session()
        vars_initialized = sess.run([tf.is_variable_initialized(var)
                                     for var in self.bert.variables])
        uninitialized = []
        for var, is_initialized in zip(self.bert.variables, vars_initialized):
            if not is_initialized:
                uninitialized.append(var)
        # completion sketch: initialize only the variables Keras did not cover
        if uninitialized:
            sess.run(tf.variables_initializer(uninitialized))
    def call(self, input):
        if self.do_preprocessing:
            input = tf.numpy_function(self.preprocessor, [input],
                                      [tf.int32, tf.int32, tf.int32])
            for feature in input:
                feature.set_shape((None, self.seq_len))
        input_ids, input_mask, segment_ids = input
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
        # completion sketch: run the module in "tokens" mode; "cls" pooling
        # returns the pooled [CLS] vector
        output = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)
        return output["pooled_output"]
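For pooling="mean", a plausible variant of the return path above (an assumption; only the "cls" branch is shown in the snippet) averages the sequence output over non-padding tokens:

        # hypothetical pooling == "mean" branch inside call():
        mask = tf.cast(tf.expand_dims(input_mask, -1), tf.float32)
        summed = tf.reduce_sum(output["sequence_output"] * mask, axis=1)
        return summed / tf.reduce_sum(mask, axis=1)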
inp = tf.keras.Input(shape=(1,), dtype=tf.string)
encoder = BertLayer(bert_path="./bert-module/", seq_len=48,
                    tune_embeddings=False, do_preprocessing=True,
                    pooling='cls', n_tune_layers=3, verbose=False)
pred = tf.keras.layers.Dense(1, activation='sigmoid')(encoder(inp))

model = tf.keras.models.Model(inputs=[inp], outputs=[pred])
model.summary()
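A hedged end-to-end sketch of training this model: compile as usual, call initialize_module() so the hub.Module weights are loaded before the first step, then fit directly on raw strings. The texts and labels here are toy stand-in data.

model.compile(optimizer=tf.keras.optimizers.Adam(lr=2e-5),
              loss="binary_crossentropy", metrics=["accuracy"])
encoder.initialize_module()  # hub weights are not covered by Keras init
texts = np.array([["a great movie"], ["a terrible movie"]])  # toy stand-in data
labels = np.array([1, 0])
model.fit(texts, labels, epochs=1, batch_size=2)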