Small BERT checkpoints vs tfhub modules
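This gist checks whether the 24 small uncased BERT checkpoints from the 2020-02-20 release produce the same outputs as the corresponding small_bert/* modules on tfhub.dev. Each model is loaded twice, once from the raw checkpoint via bert-tensorflow and once as a hub module, both are fed the same sentence pairs, and the sequence and pooled outputs are compared with np.allclose(atol=0.05); the report at the end prints OK on a match and the max absolute difference otherwise.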
#!/bin/bash
mkdir -p small_bert_checkpoints
cd small_bert_checkpoints/
# Grab the 2020-02-20 release containing all 24 small BERT checkpoints
wget https://storage.googleapis.com/bert_models/2020_02_20/all_bert_models.zip
unzip all_bert_models.zip
# Unpack each uncased_*.zip into a directory of the same name (sans .zip)
find . -name 'uncased*.zip' -exec sh -c 'unzip -d "${1%.*}" "$1"' _ {} \;
rm uncased*.zip
rm all_bert_models.zip
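Running the script leaves one directory per model under small_bert_checkpoints/ (e.g. small_bert_checkpoints/uncased_L-2_H-128_A-2, each containing bert_config.json, vocab.txt and the bert_model.ckpt.* files), which is the layout the comparison script below expects.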
bert-tensorflow==1.0.1
tensorflow==1.15.0
tensorflow-estimator==1.15.1
tensorflow-hub==0.8.0
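Assuming the pins above are saved as requirements.txt, pip install -r requirements.txt reproduces the TF 1.15 environment the comparison script was written against.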
import os
import json
import glob

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from bert.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint
from bert.run_classifier import InputExample, InputFeatures, convert_examples_to_features
from bert import tokenization

training = False


class AbstractBase(object):
    def get_outputs(self, text_as, text_bs=None):
        """Tokenize (text_a, text_b) pairs and run one forward pass through BERT."""
        text_bs = text_bs or ([None] * len(text_as))
        ies = [InputExample(str(i), text_a, text_b) for i, (text_a, text_b) in enumerate(zip(text_as, text_bs))]
        inp_fs = convert_examples_to_features(examples=ies, label_list=[None],
                                              max_seq_length=self.max_seq_length, tokenizer=self.tok)
        input_ids = [inp_f.input_ids for inp_f in inp_fs]
        input_mask = [inp_f.input_mask for inp_f in inp_fs]
        segment_ids = [inp_f.segment_ids for inp_f in inp_fs]
        bert_inputs = {
            self.input_ids: input_ids,
            self.input_mask: input_mask,
            self.segment_ids: segment_ids,
        }
        bo = self.sess.run(self.bert_outputs, feed_dict=bert_inputs)
        sequence_output = bo['sequence_output']
        pooled_output = bo['pooled_output']
        return input_ids, input_mask, segment_ids, sequence_output, pooled_output


class TFHubSmallBERT(AbstractBase):
    """Small BERT loaded as a TF Hub module."""

    def __init__(self, handle, training=False, max_seq_length=512):
        self.max_seq_length = max_seq_length
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.bert_module = hub.Module(handle, trainable=False, tags={'train'} if training else None)
            self.sess = tf.Session()
            self.sess.run(tf.group(tf.global_variables_initializer(), tf.tables_initializer()))
            # The module ships its own vocab and casing via the tokenization_info signature
            tokenization_info = self.bert_module(signature='tokenization_info', as_dict=True)
            vocab_file, do_lower_case = self.sess.run([tokenization_info['vocab_file'], tokenization_info['do_lower_case']])
            self.input_ids = tf.placeholder(tf.int32, shape=(None, self.max_seq_length))
            self.input_mask = tf.placeholder(tf.int32, shape=(None, self.max_seq_length))
            self.segment_ids = tf.placeholder(tf.int32, shape=(None, self.max_seq_length))
            bert_inputs = dict(
                input_ids=self.input_ids,
                input_mask=self.input_mask,
                segment_ids=self.segment_ids,
            )
            self.bert_outputs = self.bert_module(bert_inputs, signature='tokens', as_dict=True)
            self.tok = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)


class CheckpointSmallBERT(AbstractBase):
    """Small BERT built with bert-tensorflow and restored from a raw checkpoint."""

    def __init__(self, path, training=False, max_seq_length=512):
        self.max_seq_length = max_seq_length
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.input_ids = tf.placeholder(tf.int32, shape=(None, self.max_seq_length))
            self.input_mask = tf.placeholder(tf.int32, shape=(None, self.max_seq_length))
            self.segment_ids = tf.placeholder(tf.int32, shape=(None, self.max_seq_length))
            self.bert_config = BertConfig.from_json_file(path + '/bert_config.json')
            self.bert_module = BertModel(config=self.bert_config, is_training=training,
                                         input_ids=self.input_ids, input_mask=self.input_mask,
                                         token_type_ids=self.segment_ids, use_one_hot_embeddings=False)
            # Map checkpoint variable names onto the freshly built graph, then restore
            assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(
                tf.trainable_variables(),
                path + '/bert_model.ckpt'
            )
            tf.train.init_from_checkpoint(path + '/bert_model.ckpt', assignment_map)
            self.sess = tf.Session()
            self.sess.run(tf.group(tf.global_variables_initializer(), tf.tables_initializer()))
            self.bert_outputs = {
                'sequence_output': self.bert_module.get_sequence_output(),
                'pooled_output': self.bert_module.get_pooled_output(),
            }
            self.tok = tokenization.FullTokenizer(vocab_file=path + '/vocab.txt', do_lower_case=True)


def test(hub_handle, path):
    print('=' * 120)
    print(hub_handle, path)
    text_a = ['well read students', 'this is a model compression test']
    text_b = ['learn better', 'all okay?']
    # Use the model's own max_position_embeddings as the sequence length
    msl = json.load(open(path + '/bert_config.json'))['max_position_embeddings']
    checkpoint_model = CheckpointSmallBERT(path, training=False, max_seq_length=msl)
    hub_model = TFHubSmallBERT(f'https://tfhub.dev/google/{hub_handle}/1', training=False, max_seq_length=msl)
    chiids, chim, chsids, chso, chpo = checkpoint_model.get_outputs(text_a, text_b)
    tfiids, tfim, tfsids, tfso, tfpo = hub_model.get_outputs(text_a, text_b)
    # On failure the assertion message is the max absolute element-wise difference
    assert np.allclose(chso, tfso, atol=0.05), np.max(np.abs(tfso - chso))
    assert np.allclose(chpo, tfpo, atol=0.05), np.max(np.abs(tfpo - chpo))
    # some forced cleanup
    del checkpoint_model
    del hub_model


if __name__ == '__main__':
    f = open('test_report.txt', 'w')
    for chkpoint in glob.glob('small_bert_checkpoints/uncased_*'):
        try:
            # e.g. small_bert_checkpoints/uncased_L-2_H-128_A-2 -> small_bert/bert_uncased_L-2_H-128_A-2
            hub_handle = 'small_bert/bert_' + chkpoint.split('/', 1)[-1]
            print('Testing', hub_handle, 'with', chkpoint, file=f, flush=True)
            test(hub_handle, chkpoint)
            print('OK', file=f, flush=True)
        except AssertionError as e:
            print(e, file=f, flush=True)
    f.close()
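For reference, a minimal sketch of using one of these wrappers on its own (the checkpoint directory and max_seq_length=128 here are illustrative; any of the unpacked uncased_* directories works):

# Minimal usage sketch, assuming the download script above has been run
model = CheckpointSmallBERT('small_bert_checkpoints/uncased_L-2_H-128_A-2',
                            training=False, max_seq_length=128)
_, _, _, sequence_output, pooled_output = model.get_outputs(
    ['well read students'], ['learn better'])
print(sequence_output.shape)  # (1, 128, 128): (batch, seq_len, hidden_size)
print(pooled_output.shape)    # (1, 128): (batch, hidden_size)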
Testing small_bert/bert_uncased_L-8_H-768_A-12 with small_bert_checkpoints/uncased_L-8_H-768_A-12
6.343616
Testing small_bert/bert_uncased_L-2_H-128_A-2 with small_bert_checkpoints/uncased_L-2_H-128_A-2
OK
Testing small_bert/bert_uncased_L-12_H-128_A-2 with small_bert_checkpoints/uncased_L-12_H-128_A-2
OK
Testing small_bert/bert_uncased_L-4_H-768_A-12 with small_bert_checkpoints/uncased_L-4_H-768_A-12
2.3570232
Testing small_bert/bert_uncased_L-8_H-512_A-8 with small_bert_checkpoints/uncased_L-8_H-512_A-8
OK
Testing small_bert/bert_uncased_L-6_H-128_A-2 with small_bert_checkpoints/uncased_L-6_H-128_A-2
OK
Testing small_bert/bert_uncased_L-12_H-256_A-4 with small_bert_checkpoints/uncased_L-12_H-256_A-4
OK
Testing small_bert/bert_uncased_L-6_H-768_A-12 with small_bert_checkpoints/uncased_L-6_H-768_A-12
1.358148
Testing small_bert/bert_uncased_L-12_H-512_A-8 with small_bert_checkpoints/uncased_L-12_H-512_A-8
OK
Testing small_bert/bert_uncased_L-6_H-256_A-4 with small_bert_checkpoints/uncased_L-6_H-256_A-4
OK
Testing small_bert/bert_uncased_L-8_H-128_A-2 with small_bert_checkpoints/uncased_L-8_H-128_A-2
OK
Testing small_bert/bert_uncased_L-6_H-512_A-8 with small_bert_checkpoints/uncased_L-6_H-512_A-8
OK
Testing small_bert/bert_uncased_L-10_H-128_A-2 with small_bert_checkpoints/uncased_L-10_H-128_A-2
OK
Testing small_bert/bert_uncased_L-4_H-512_A-8 with small_bert_checkpoints/uncased_L-4_H-512_A-8
OK
Testing small_bert/bert_uncased_L-10_H-512_A-8 with small_bert_checkpoints/uncased_L-10_H-512_A-8
OK
Testing small_bert/bert_uncased_L-10_H-256_A-4 with small_bert_checkpoints/uncased_L-10_H-256_A-4
OK
Testing small_bert/bert_uncased_L-8_H-256_A-4 with small_bert_checkpoints/uncased_L-8_H-256_A-4
OK
Testing small_bert/bert_uncased_L-12_H-768_A-12 with small_bert_checkpoints/uncased_L-12_H-768_A-12
6.4170265
Testing small_bert/bert_uncased_L-2_H-768_A-12 with small_bert_checkpoints/uncased_L-2_H-768_A-12
2.6249537
Testing small_bert/bert_uncased_L-2_H-256_A-4 with small_bert_checkpoints/uncased_L-2_H-256_A-4
OK
Testing small_bert/bert_uncased_L-4_H-128_A-2 with small_bert_checkpoints/uncased_L-4_H-128_A-2
OK
Testing small_bert/bert_uncased_L-4_H-256_A-4 with small_bert_checkpoints/uncased_L-4_H-256_A-4
OK
Testing small_bert/bert_uncased_L-10_H-768_A-12 with small_bert_checkpoints/uncased_L-10_H-768_A-12
9.500095
Testing small_bert/bert_uncased_L-2_H-512_A-8 with small_bert_checkpoints/uncased_L-2_H-512_A-8
OK
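Summary: every H-768 configuration fails the atol=0.05 comparison (the bare numbers above are the assertion messages, i.e. the maximum absolute element-wise difference), ranging from ~1.36 for L-6_H-768 up to ~9.50 for L-10_H-768, while all H-128, H-256 and H-512 variants match their tfhub modules within tolerance.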