Created
February 3, 2017 14:38
-
-
Save stefanthaler/7240f62d78de0b1a34ad2029e3d2336b to your computer and use it in GitHub Desktop.
A simple example to demonstrate how to link embedding metadata to word embeddings in tensorflow / tensorboard
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Simple example to demonstrate the embedding visualization for word embeddings in tensorflow / tensorboard
https://www.tensorflow.org/how_tos/embedding_viz/ | |
""" | |
import os

import tensorflow as tf

# The script was written against exactly this TensorFlow release. Fail early
# and report the version that was actually found, so a mismatch is obvious.
assert tf.__version__ == '1.0.0-rc0', (
    "expected tensorflow 1.0.0-rc0, found %s" % tf.__version__
)  # if code breaks, check tensorflow version
from tensorflow.contrib.tensorboard.plugins import projector
""" | |
Hyperparameter | |
""" | |
# Directory that receives the checkpoint, the summary event file and the
# metadata tsv written further down in this script.
checkpoint_path = "checkpoints"
# Create the directory without a separate existence check: a check followed
# by mkdir can race with another process, and mkdir cannot create nested
# paths. An already existing directory is fine; anything else is re-raised.
try:
    os.makedirs(checkpoint_path)
except OSError:
    if not os.path.isdir(checkpoint_path):
        raise
vocabulary_size = 20  # we have 20 words in our vocabulary
word_embedding_dim = 10  # each word is represented by a [1,10] dimensional row vector.
""" | |
Create Word Embedding | |
""" | |
# Embedding matrix with one row per vocabulary entry and one column per
# embedding dimension, initialised with uniform random values between
# -1.0 and 1.0.
initial_embeddings = tf.random_uniform(
    shape=[vocabulary_size, word_embedding_dim],
    minval=-1.0,
    maxval=1.0,
)
word_embeddings = tf.Variable(initial_embeddings, name='word_embeddings')
""" | |
Save Word embedding checkpoint | |
""" | |
# Saver covering every global variable defined so far.
saver = tf.train.Saver(tf.global_variables())

# Run the session as a context manager so its resources are released as soon
# as the checkpoint has been written, even if initialisation or saving fails.
with tf.Session() as session:
    # The writer is created inside the session scope because it needs the
    # session's graph. It is deliberately left open: it is used again later
    # in the script when the projector config is written.
    summary_writer = tf.summary.FileWriter(checkpoint_path, graph=session.graph)
    session.run([tf.global_variables_initializer()])  # init variables

    # ... do stuff with session

    # save checkpoints periodically
    chkpoint_out_filename = os.path.join(checkpoint_path, "word_embedding_sample")
    saver.save(session, chkpoint_out_filename, global_step=1)
print("\nword_embeddings checkpoint saved")
""" | |
Write metadata file | |
""" | |
# Write the metadata file: a header row followed by one tab-separated row
# (name, category, type) per word id, in ascending id order.
tsv_row_template = "{}\t{}\t{}\n"
with open(os.path.join(checkpoint_path, 'word_embeddings.tsv'), "w") as f:
    header_row = tsv_row_template.format("Name", "Category", "Type")
    f.write(header_row)
    # range instead of the Python-2-only xrange, so the loop runs on both
    # Python 2 and Python 3.
    for w_id in range(vocabulary_size):
        # get (synthetic) metadata for each word
        word = "word %0.2d" % w_id
        category = w_id % 5
        word_type = "type %i" % (w_id % 3)
        data_row = tsv_row_template.format(word, category, word_type)
        f.write(data_row)
print("word_embeddings.tsv written.")
""" | |
Link metadata tsv file to embedding | |
""" | |
# Build the projector configuration that ties the metadata file to the
# embedding variable.
config = projector.ProjectorConfig()
embedding = config.embeddings.add()  # could add more metadata files here
embedding.tensor_name = word_embeddings.name
# NOTE(review): this path is relative to the current working directory, which
# matches running tensorboard from the same directory as this script; confirm
# how the TensorBoard version in use resolves relative metadata paths.
embedding.metadata_path = os.path.join(checkpoint_path, 'word_embeddings.tsv')
projector.visualize_embeddings(summary_writer, config)
# The writer is not needed after this point; close it so pending events are
# flushed to disk and the event file handle is released.
summary_writer.close()
print("Metadata linked to checkpoint\n")
print("run: tensorboard --logdir checkpoints/")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment