Computation graph operations for tokenizing, masking each token in turn, and adding [CLS]/[SEP] (BOS/EOS) tokens for processing by BERT
import tensorflow as tf
import tensorflow_text as text

tf.enable_eager_execution()  # needed only on TF 1.x; TF 2.x runs eagerly by default

# Token ids of [CLS], [SEP], and [MASK] in the standard BERT uncased vocab.txt
CLS_TOKEN, SEP_TOKEN, MASK_TOKEN = 101, 102, 103

def merge_dims(rt, axis=0):
    # Collapse ragged dimension `axis + 1` into dimension `axis`: for the
    # [batch, words, wordpieces] tokenizer output this yields a
    # [batch, wordpieces] ragged tensor plus the wordpieces-per-sentence counts.
    to_expand = rt.nested_row_lengths()[axis]
    to_elim = rt.nested_row_lengths()[axis + 1]
    grouped = tf.RaggedTensor.from_row_lengths(to_elim, row_lengths=to_expand)
    new_row_lengths = tf.reduce_sum(grouped, axis=1)
    merged = tf.RaggedTensor.from_nested_row_lengths(
        rt.flat_values, rt.nested_row_lengths()[:axis] + (new_row_lengths,))
    return merged, new_row_lengths
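
# Illustrative sketch (not in the original gist) of what merge_dims produces for
# a tiny [batch, words, wordpieces] ragged tensor:
#   rt = tf.ragged.constant([[[1, 2], [3]], [[4], [5, 6], [7]]])
#   merged, lengths = merge_dims(rt, axis=0)
#   merged  -> [[1, 2, 3], [4, 5, 6, 7]]   # word/wordpiece axes collapsed
#   lengths -> [3, 4]                      # wordpieces per sentence
# On recent TF versions the same merge can be done with rt.merge_dims(1, 2).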

def get_each_token_masked(sample, tokenizer):
    # Tokenize, and convert from ragged tensors to dense (padded) tensors.
    ragged_tokens = tokenizer.tokenize(sample)
    merged_tokens, seq_lengths = merge_dims(ragged_tokens, axis=0)
    merged_tokens = tf.cast(merged_tokens.to_tensor(), dtype=tf.int32)
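    # ragged_tokens has shape [batch, words, wordpieces]; after the merge,
    # merged_tokens is a zero-padded [batch, max_wordpieces_per_sentence] tensor.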
    # Generate indices referring to the positions to mask.
    batch_size = tf.shape(merged_tokens)[0]
    max_length = tf.shape(merged_tokens)[1]
    new_batch_size = batch_size * max_length
    bsz_range = tf.range(batch_size)
    seql_range = tf.range(max_length)
    mask_fill = tf.fill([new_batch_size], MASK_TOKEN)
    # (batch, position) pairs for every token slot in the padded batch ...
    bsz_word_idx = tf.reshape(
        tf.stack(tf.meshgrid(bsz_range, seql_range, indexing='ij'), axis=-1), [-1, 2])
    # ... extended to (batch, copy, position) triples with copy == position,
    # i.e. the diagonal of the per-token copies created below.
    bsz_word_mask_idx = tf.concat(
        [bsz_word_idx, tf.expand_dims(bsz_word_idx[:, 1], axis=-1)], axis=-1)
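    # Illustrative sketch (not in the original gist): for batch_size=2 and
    # max_length=3, bsz_word_mask_idx is
    #   [[0, 0, 0], [0, 1, 1], [0, 2, 2],
    #    [1, 0, 0], [1, 1, 1], [1, 2, 2]]
    # i.e. one (example, copy, position) triple per token slot.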
    # Repeat each example max_length times (one copy per token to mask), and
    # replace the word token with [MASK] along the diagonal of each copy.
    repeated_toks = tf.repeat(merged_tokens, max_length, axis=0)
    repeated_seq_lengths = tf.repeat(seq_lengths, max_length, axis=0)
    repeated_toks = tf.reshape(repeated_toks, [batch_size, max_length, max_length])
    real_tokens = tf.gather_nd(repeated_toks, bsz_word_mask_idx)  # the tokens being masked
    masked_tokens = tf.tensor_scatter_nd_update(repeated_toks, bsz_word_mask_idx, mask_fill)
    masked_tokens = tf.reshape(masked_tokens, [new_batch_size, max_length])
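    # Illustrative sketch (not in the original gist): a sentence tokenized as
    # [5, 6, 7] (max_length=3) expands into three masked copies
    #   [[MASK, 6, 7], [5, MASK, 7], [5, 6, MASK]]
    # with real_tokens contributing [5, 6, 7] as the prediction targets.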
    # Drop copies whose masked position is padding - assumes the PAD id is 0.
    non_pad = tf.reshape(tf.where(real_tokens), [-1])
    pred_tokens = tf.gather(real_tokens, non_pad)
    pred_masked_tokens = tf.gather(masked_tokens, non_pad)
    pred_seq_lengths = tf.gather(repeated_seq_lengths, non_pad)
    pred_nd_idx = tf.gather(bsz_word_mask_idx, non_pad)[:, -1]  # masked position within each copy
    bsz = tf.shape(pred_masked_tokens)[0]
    # Add [CLS] at the beginning and [SEP] at the end of each (unpadded) sequence.
    # [SEP] is placed before the padding by reversing each sequence, prepending
    # [SEP], and reversing back over seq_length + 1 tokens.
    sep_tokens = tf.fill([bsz, 1], SEP_TOKEN)
    cls_tokens = tf.fill([bsz, 1], CLS_TOKEN)
    sep_added_tokens = tf.concat(
        [sep_tokens, tf.reverse_sequence(pred_masked_tokens, pred_seq_lengths,
                                         batch_axis=0, seq_axis=1)], axis=1)
    cls_sep_added_tokens = tf.concat(
        [cls_tokens, tf.reverse_sequence(sep_added_tokens, pred_seq_lengths + 1,
                                         batch_axis=0, seq_axis=1)], axis=1)
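    # Illustrative sketch (not in the original gist): with pred_masked_tokens
    # row [5, MASK, 7, 0] and seq_length 3, the two reverse/concat steps give
    #   [SEP, 7, MASK, 5, 0]  ->  [CLS, 5, MASK, 7, SEP, 0]
    # so [SEP] lands before the padding rather than after it.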
    # Setup prediction indexes of the mask tokens for use with gather_nd
    pred_nd_idx = 1 + pred_nd_idx  # account for [CLS]
    pred_nd_idx = tf.transpose(tf.stack([tf.range(bsz), pred_nd_idx], axis=0))
    return cls_sep_added_tokens, pred_tokens, (pred_seq_lengths + 2), pred_nd_idx

# Build the wordpiece vocabulary lookup table from the pretrained BERT vocab
# (one OOV bucket) and run the masking pipeline on a small sample.
with open('pretrained/vocab.txt', 'r') as fin:
    words = [x.strip() for x in fin.readlines()]

init_table = tf.lookup.KeyValueTensorInitializer(
    words,
    tf.range(tf.size(words, out_type=tf.int64), dtype=tf.int64),
    key_dtype=tf.string,
    value_dtype=tf.int64)
table = tf.lookup.StaticVocabularyTable(init_table, 1, lookup_key_dtype=tf.string)
tokenizer = text.BertTokenizer(table, lower_case=True)

sample = ["this one", "and this whun"]
tokens, preds, seq_lengths, pred_nd_idx = get_each_token_masked(sample, tokenizer)
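
# Illustrative usage sketch (not part of the original gist): how the outputs
# could be scored against a masked-LM head. `dummy_logits` stands in for the
# per-position logits a BERT model would produce for `tokens`; only the shapes
# matter here.
num_examples = tf.shape(tokens)[0]
padded_len = tf.shape(tokens)[1]
vocab_size = len(words) + 1  # +1 for the single OOV bucket in the lookup table
dummy_logits = tf.random.uniform([num_examples, padded_len, vocab_size])
masked_logits = tf.gather_nd(dummy_logits, pred_nd_idx)  # logits at each [MASK] position
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=preds, logits=masked_logits))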