Computation-graph operations for tokenizing a batch of sentences, masking each token in turn, and adding BOS/EOS ([CLS]/[SEP]) tokens for processing by BERT
import tensorflow as tf
import tensorflow_text as text

tf.enable_eager_execution()

# Token ids from the standard uncased BERT vocab: [PAD]=0, [CLS]=101, [SEP]=102, [MASK]=103
CLS_TOKEN, SEP_TOKEN, MASK_TOKEN = 101, 102, 103

def merge_dims(rt, axis=0):
    # Merge ragged dimensions `axis` and `axis + 1` of `rt`, returning the
    # merged RaggedTensor along with its new row lengths
    to_expand = rt.nested_row_lengths()[axis]
    to_elim = rt.nested_row_lengths()[axis + 1]
    # Group the inner row lengths by the outer ones; summing each group gives
    # the row lengths of the merged dimension
    grouped = tf.RaggedTensor.from_row_lengths(to_elim, row_lengths=to_expand)
    new_row_lengths = tf.reduce_sum(grouped, axis=1)
    return tf.RaggedTensor.from_nested_row_lengths(
        rt.flat_values, rt.nested_row_lengths()[:axis] + (new_row_lengths,)), new_row_lengths
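
# Illustrative note (not part of the original gist): merge_dims flattens the
# word/wordpiece nesting produced by BertTokenizer. E.g., assuming
#   rt = tf.ragged.constant([[[1, 2], [3]], [[4], [5, 6], [7]]])
# then merge_dims(rt, axis=0) yields [[1, 2, 3], [4, 5, 6, 7]] with row
# lengths [3, 4]; newer TF versions expose this as rt.merge_dims(1, 2).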

def get_each_token_masked(sample, tokenizer):
    # Tokenize, and convert from ragged tensors to dense, zero-padded tensors
    ragged_tokens = tokenizer.tokenize(sample)
    merged_tokens, seq_lengths = merge_dims(ragged_tokens, axis=0)
    merged_tokens = tf.cast(merged_tokens.to_tensor(), dtype=tf.int32)
    # Generate indices referring to the positions to mask
    batch_size = tf.shape(merged_tokens)[0]
    max_length = tf.shape(merged_tokens)[1]
    new_batch_size = batch_size * max_length
    bsz_range = tf.range(batch_size)
    seql_range = tf.range(max_length)
    mask_fill = tf.fill([new_batch_size], MASK_TOKEN)
    bsz_word_idx = tf.reshape(
        tf.stack(tf.meshgrid(bsz_range, seql_range, indexing='ij'), axis=-1), [-1, 2])
    bsz_word_mask_idx = tf.concat(
        [bsz_word_idx, tf.expand_dims(bsz_word_idx[:, 1], axis=-1)], axis=-1)
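    # Illustration (assumed shapes, not from the original gist): for
    # batch_size=2, max_length=3, bsz_word_idx enumerates every
    # (batch, position) pair, and bsz_word_mask_idx repeats the position so it
    # indexes the diagonal of the [batch, copy, position] tensor built below:
    #   bsz_word_mask_idx = [[0,0,0], [0,1,1], [0,2,2],
    #                        [1,0,0], [1,1,1], [1,2,2]]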
    # Repeat along seq_length to get a mask_tok_idx dimension, and replace each
    # word token with the mask token along the diagonal
    repeated_toks = tf.repeat(merged_tokens, max_length, axis=0)
    repeated_seq_lengths = tf.repeat(seq_lengths, max_length, axis=0)
    repeated_toks = tf.reshape(repeated_toks, [batch_size, max_length, max_length])
    real_tokens = tf.gather_nd(repeated_toks, bsz_word_mask_idx)
    masked_tokens = tf.tensor_scatter_nd_update(repeated_toks, bsz_word_mask_idx, mask_fill)
    masked_tokens = tf.reshape(masked_tokens, [new_batch_size, max_length])
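    # Sketch of the result (assumed, for one sentence [a, b, c]):
    #   masked_tokens rows: [MASK, b, c], [a, MASK, c], [a, b, MASK]
    #   real_tokens: [a, b, c] - the original token at each masked slot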
    # Remove PAD predictions - assumes PAD is 0
    non_pad = tf.reshape(tf.where(real_tokens), [-1])
    pred_tokens = tf.gather(real_tokens, non_pad)
    pred_masked_tokens = tf.gather(masked_tokens, non_pad)
    pred_seq_lengths = tf.gather(repeated_seq_lengths, non_pad)
    pred_nd_idx = tf.gather(bsz_word_mask_idx, non_pad)[:, -1]
    bsz = tf.shape(pred_masked_tokens)[0]
    # Add [CLS] at the start and [SEP] at the end of each (padded) row: reverse
    # the valid tokens, prepend [SEP], then reverse back to append it in place
    sep_tokens = tf.fill([bsz, 1], SEP_TOKEN)
    cls_tokens = tf.fill([bsz, 1], CLS_TOKEN)
    sep_added_tokens = tf.concat(
        [sep_tokens, tf.reverse_sequence(pred_masked_tokens, pred_seq_lengths,
                                         batch_axis=0, seq_axis=1)], axis=1)
    cls_sep_added_tokens = tf.concat(
        [cls_tokens, tf.reverse_sequence(sep_added_tokens, pred_seq_lengths + 1,
                                         batch_axis=0, seq_axis=1)], axis=1)
    # Set up prediction indices of the mask tokens for use with gather_nd
    pred_nd_idx = 1 + pred_nd_idx  # account for [CLS]
    pred_nd_idx = tf.transpose(tf.stack([tf.range(bsz), pred_nd_idx], axis=0))
    return cls_sep_added_tokens, pred_tokens, (pred_seq_lengths + 2), pred_nd_idx
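
# Consumption sketch (assumed, not from the original gist): given per-position
# logits from any masked-LM BERT implementation, e.g.
#   logits = bert_model(cls_sep_added_tokens)          # hypothetical call,
#                                                      # [copies, seq_len, vocab]
#   masked_logits = tf.gather_nd(logits, pred_nd_idx)  # [copies, vocab]
# indexing masked_logits by pred_tokens gives one score per (sentence, position),
# i.e. the per-token terms of a pseudo-log-likelihood.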

# Build a lookup table from the pretrained BERT vocab file (one OOV bucket),
# and wrap it in a wordpiece tokenizer
with open('pretrained/vocab.txt', 'r') as fin:
    words = [x.strip() for x in fin.readlines()]
init_table = tf.lookup.KeyValueTensorInitializer(
    words,
    tf.range(tf.size(words, out_type=tf.int64), dtype=tf.int64),
    key_dtype=tf.string,
    value_dtype=tf.int64)
table = tf.lookup.StaticVocabularyTable(init_table, 1, lookup_key_dtype=tf.string)
tokenizer = text.BertTokenizer(table, lower_case=True)

sample = ["this one", "and this whun"]
tokens, preds, seq_lengths, pred_nd_idx = get_each_token_masked(sample, tokenizer)
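
# Expected structure of the outputs (a sketch; exact ids depend on the vocab):
# `tokens` has one row per non-pad token in the batch - the full sentence with
# that token replaced by [MASK] and wrapped in [CLS]/[SEP], e.g. for "this one":
#   [CLS] [MASK] one    [SEP] (pad...)
#   [CLS] this   [MASK] [SEP] (pad...)
# `preds` holds the original ids at the masked positions, `seq_lengths` the
# lengths including [CLS]/[SEP], and `pred_nd_idx` the (row, position) pairs
# that gather the masked-position logits.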