Skip to content

Instantly share code, notes, and snippets.

@hans
Last active May 18, 2019 18:04
Show Gist options
  • Save hans/ad1e66c679a8656c6356127de698855c to your computer and use it in GitHub Desktop.
Save hans/ad1e66c679a8656c6356127de698855c to your computer and use it in GitHub Desktop.
Visualize a BERT masked-LM example drawn from a TFRecord file.
def viz_example(ex, vocab=None):
    """Print a masked-LM example as aligned rows of tokens and masked targets.

    The example is printed in rows of ten tokens. Each row of input tokens is
    followed by a row showing, under every masked position, the word stored
    for that position in ``masked_lm_ids``, and then by a blank line. Every
    cell is right-aligned to ten characters.

    Args:
        ex: An object shaped like ``tf.train.Example``: it must expose
            ``ex.features.feature[name].int64_list.value`` for the names
            ``"input_ids"``, ``"masked_lm_ids"`` and ``"masked_lm_positions"``.
        vocab: Sequence mapping a token id to its string. When omitted, the
            module-level vocabulary list ``v`` is used.

    Returns:
        None. The visualization is written to standard output.

    Raises:
        IndexError: If a token id is outside the vocabulary.
    """
    if vocab is None:
        vocab = v

    cell_width = 10
    cells_per_row = 10
    blank = " " * cell_width

    features = ex.features.feature
    input_ids = features["input_ids"].int64_list.value
    target_ids = features["masked_lm_ids"].int64_list.value
    target_positions = features["masked_lm_positions"].int64_list.value

    # Pair each target with its position *before* discarding non-positive ids
    # (presumably padding; id 0 is [PAD] in the standard BERT vocab), so a
    # dropped entry can never shift the remaining targets onto the wrong
    # positions.
    targets = {
        pos: vocab[tok_id]
        for pos, tok_id in zip(target_positions, target_ids)
        if tok_id > 0
    }

    # masked_lm_positions index into the unfiltered input_ids, so look targets
    # up by the original position rather than by the index after filtering.
    tok_cells = []
    target_cells = []
    for pos, tok_id in enumerate(input_ids):
        if tok_id <= 0:
            continue
        tok_cells.append(vocab[tok_id].rjust(cell_width))
        target_cells.append(
            targets[pos].rjust(cell_width) if pos in targets else blank
        )

    for start in range(0, len(tok_cells), cells_per_row):
        stop = start + cells_per_row
        print(" ".join(tok_cells[start:stop]))
        print(" ".join(target_cells[start:stop]))
        print()
from pathlib import Path
import tensorflow as tf
# Load the BERT vocabulary into the module-level list `v`, which viz_example
# indexes by token id. The position of a line in vocab.txt is the token's id,
# so every line is kept (blank or not) to avoid shifting later ids.
vocab_path = Path("~/om2/others/bert/uncased_L-12_H-768_A-12/vocab.txt").expanduser()
# The file is decoded explicitly as UTF-8 instead of the locale default,
# since the vocabulary contains non-ASCII tokens.
with vocab_path.open("r", encoding="utf-8") as f:
    v = [line.strip() for line in f]

# Visualize the first five serialized examples of the TFRecord file.
# NOTE(review): tf.python_io is the TensorFlow 1.x namespace; under
# TensorFlow 2.x this iterator is presumably reached through tf.compat.v1.io
# -- confirm against the installed version.
records = tf.python_io.tf_record_iterator("books_full.train.scrambled.tfrecord")
for _, example in zip(range(5), records):
    viz_example(tf.train.Example.FromString(example))
    print("==========")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment