Created
March 21, 2018 13:35
-
-
Save chiragjn/69bdf49e0268b2f708821c2a1f2cca18 to your computer and use it in GitHub Desktop.
Put your vectors onto tensorboard!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoding: utf-8 | |
""" | |
Original credits | |
@author: BrikerMan | |
@contact: [email protected] | |
@blog: https://eliyar.biz | |
@version: 1.0 | |
@license: Apache Licence | |
@file: w2v_visualizer.py | |
@time: 2017/7/30 上午9:37 | |
""" | |
import sys | |
import os | |
# from gensim.models import Word2Vec | |
import tensorflow as tf | |
import numpy as np | |
import pickle | |
from tensorflow.contrib.tensorboard.plugins import projector | |
def visualize(arr, labels, output_path):
    """Export vectors and their labels so the TensorBoard projector can load them.

    Three artifacts are written into ``output_path``:

    * ``vis_metadata.tsv`` -- one label per line, in row order;
    * a checkpoint (``vec_metadata.ckpt``) holding ``arr`` as a non-trainable
      variable named ``vis_metadata``;
    * the projector config that ties the variable to the metadata file.

    Args:
        arr: array of vectors; ``arr.shape[0]`` rows are exported.
        labels: sequence indexed by row number that yields the label for
            each row of ``arr``. Text, bytes and other objects are accepted;
            non-bytes values are formatted and encoded as UTF-8.
        output_path: directory to write into. Created if it does not exist.

    Raises:
        IndexError: if ``labels`` has fewer entries than ``arr`` has rows.
    """
    meta_file = "vis_metadata.tsv"
    tensor_name = 'vis_metadata'
    print(arr.shape)

    # The metadata file is written before the FileWriter is constructed, so
    # the directory must already exist at this point.
    if not os.path.isdir(output_path):
        os.makedirs(output_path)

    with open(os.path.join(output_path, meta_file), 'wb') as file_metadata:
        for i in range(arr.shape[0]):
            word = labels[i]
            # Bytes are written verbatim; everything else is formatted as
            # unicode first so non-ASCII labels encode cleanly on Python 2
            # and Python 3 alike.
            if isinstance(word, bytes):
                encoded = word
            else:
                encoded = u"{0}".format(word).encode('utf-8')
            if encoded == b'':
                # temporary solution for
                # https://github.com/tensorflow/tensorflow/issues/9094
                print("Empty label at row {0}; writing '<Empty Line>' instead "
                      "because an empty metadata line causes a bug in "
                      "tensorboard".format(i))
                encoded = b'<Empty Line>'
            file_metadata.write(encoded + b'\n')

    # Define the model without training. A fresh graph keeps the variable
    # name stable ('vis_metadata') even if this function is called more than
    # once in the same process; the session is closed on exit.
    with tf.Graph().as_default(), tf.Session() as sess:
        tf.Variable(arr, trainable=False, name=tensor_name)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        writer = tf.summary.FileWriter(output_path, sess.graph)
        try:
            # Register the embedding with the projector.
            config = projector.ProjectorConfig()
            embed = config.embeddings.add()
            embed.tensor_name = tensor_name
            embed.metadata_path = meta_file
            projector.visualize_embeddings(writer, config)
            saver.save(sess, os.path.join(output_path, 'vec_metadata.ckpt'))
        finally:
            # Flush and release the event file even if saving fails.
            writer.close()

    print('Run `tensorboard --logdir={0}` to run visualize result on tensorboard'.format(output_path))
if __name__ == "__main__":
    # Usage: python <this_script> <array_path_without_extension> <output_dir>
    # Loads <array_path>.npy and <array_path>_labels.pkl, then writes the
    # projector files into <output_dir>.
    if len(sys.argv) < 3:
        print("Please provide model path and output path")
        # Stop here: continuing without the arguments would only fail later
        # with an unrelated NameError.
        sys.exit(1)
    model_path = sys.argv[1]
    output_path = sys.argv[2]

    model = np.load(model_path + '.npy')
    # NOTE(security): unpickling can execute arbitrary code; only point this
    # script at label files you created yourself or otherwise trust.
    with open(model_path + '_labels.pkl', 'rb') as labels_file:
        labels = pickle.load(labels_file)
    visualize(model, labels, output_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Prerequisites:
A 2D NumPy array (one row per vector) and a pickle file containing a list of the corresponding labels (one label per row)
Usage:
python vecvis.py <numpy_array_filename_without_extension> <log_folder>
If the 2D NumPy array is saved as mymodel.npy, the labels pickle file must be named mymodel_labels.pkl
Run
python vecvis.py mymodel tensorboard_logs