Created
February 5, 2018 03:49
-
-
Save martinsotir/b2caa85dcd2993381cf03e9be358bd61 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Utility script to visualize embeddings using the tensorboard projector module. | |
Usage | |
----- | |
Dependencies : numpy, pillow, pandas, tensorflow | |
Call `prepare_projection(embedding, metadata, image_paths, ...)`, where : | |
- `embedding` is a 2D numpy array (`n_sample` x `dim_embedding`) | |
- `metadata` is a pandas DataFrame of length `n_sample`, containing descriptive data for each sample (optional)
- `image_paths` is a list of sample image paths (optional) | |
Then start tensorboard: | |
```bash | |
tensorboard --logdir ./projector_log | |
``` | |
References | |
---------- | |
- TensorBoard: Embedding Visualization : https://www.tensorflow.org/versions/r0.12/how_tos/embedding_viz/ | |
- Tensorboard-own-image-data-image-features-embedding-visualization : https://github.com/anujshah1003/Tensorboard-own-image-data-image-features-embedding-visualization | |
""" | |
from typing import Optional, List | |
import os | |
import numpy as np | |
import pandas as pd | |
from PIL import Image | |
import tensorflow as tf | |
from tensorflow.contrib.tensorboard.plugins import projector | |
def create_sprite(paths, w, h):
    """Tile the images found at `paths` into a single square sprite sheet.

    Every image is loaded, converted to RGB and resized to `w` x `h` pixels.
    The thumbnails are placed row by row (left to right, top to bottom) on an
    N x N grid, where N = ceil(sqrt(len(paths))). Grid cells left over after
    the last image are filled with zeros (black).

    Parameters
    ----------
    paths : iterable of str
        Image file paths, in the order the thumbnails must appear.
    w, h : int
        Width and height, in pixels, of each thumbnail.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape (N * h, N * w, 3).
    """
    thumbnails = []
    for path in paths:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the numpy array.
        with Image.open(path) as img:
            # Force three channels: grayscale, palette or RGBA inputs would
            # otherwise produce arrays that cannot be stacked together or
            # reshaped to (..., 3) below.
            # Image.LANCZOS is the filter formerly exposed as Image.ANTIALIAS.
            thumb = img.convert('RGB').resize((w, h), resample=Image.LANCZOS)
            thumbnails.append(np.asarray(thumb, dtype=np.uint8))

    imgs = np.stack(thumbnails)  # (n, h, w, 3)
    n_images = imgs.shape[0]
    N = int(np.ceil(np.sqrt(n_images)))

    # Pad with black thumbnails so that the grid is completely filled.
    padded = np.pad(
        imgs,
        pad_width=[(0, N ** 2 - n_images), (0, 0), (0, 0), (0, 0)],
        mode='constant',
        constant_values=0,
    )

    # (N*N, h, w, 3) -> (N, N, h, w, 3) -> (N, h, N, w, 3) -> (N*h, N*w, 3)
    grid = padded.reshape([N, N, *padded.shape[1:]])
    return np.moveaxis(grid, 1, 2).reshape((N * h, N * w, 3))
def prepare_projection(embedding: np.array,
                       metadata: Optional[pd.DataFrame] = None,
                       image_paths: Optional[List[str]] = None,
                       projection_dir: str = 'projector_log',
                       tensor_name: str = "embedding",
                       sprite_w: int = 128, sprite_h: int = 128):
    """Write the files needed to display `embedding` in the TensorBoard projector.

    The following files are written inside `projection_dir`:

    - a checkpoint (`embedding.ckpt`) holding the embedding as a float32
      variable named `tensor_name`;
    - `sprites.jpg`, a sprite sheet built from `image_paths` (only when
      `image_paths` is given);
    - `metadata.tsv`, a tab-separated dump of `metadata` (only when
      `metadata` is given);
    - the projector configuration referencing the files above.

    Parameters
    ----------
    embedding : 2D numpy array, one row per sample.
    metadata : optional DataFrame with one row per sample.
    image_paths : optional list with one image path per sample.
    projection_dir : output directory; created if it does not exist.
    tensor_name : name given to the TensorFlow variable holding the embedding.
    sprite_w, sprite_h : size in pixels of each thumbnail in the sprite sheet.

    Raises
    ------
    AssertionError
        If `metadata` or `image_paths` does not have one entry per embedding
        row, or if a sprite dimension is None while `image_paths` is given.
    """
    n_samples = embedding.shape[0]
    if metadata is not None:
        assert len(metadata) == n_samples, \
            "metadata must have one row per embedding row"
    if image_paths is not None:
        assert len(image_paths) == n_samples, \
            "image_paths must have one entry per embedding row"
        assert (sprite_w is not None) and (sprite_h is not None), \
            "sprite_w and sprite_h are required when image_paths is given"

    # The checkpoint, sprite and metadata files are all written into this
    # directory, so make sure it exists before anything is saved.
    os.makedirs(projection_dir, exist_ok=True)

    embedding_var = tf.Variable(tf.convert_to_tensor(embedding, np.float32), name=tensor_name)

    with tf.Session() as sess:
        saver = tf.train.Saver([embedding_var])
        sess.run(embedding_var.initializer)
        saver.save(sess, os.path.join(projection_dir, 'embedding.ckpt'))

    config = projector.ProjectorConfig()
    embedding_elem = config.embeddings.add()
    embedding_elem.tensor_name = embedding_var.name

    if image_paths is not None:
        sprites = create_sprite(image_paths, sprite_w, sprite_h)
        # Paths stored in the config are relative to `projection_dir`.
        image_path = 'sprites.jpg'
        Image.fromarray(sprites).save(os.path.join(projection_dir, image_path))
        embedding_elem.sprite.image_path = image_path
        embedding_elem.sprite.single_image_dim.extend([sprite_h, sprite_w])

    if metadata is not None:
        metadata_path = 'metadata.tsv'
        # The projector expects NO header row when the metadata file has a
        # single column; writing one would shift every label by one sample.
        write_header = metadata.shape[1] > 1
        metadata.to_csv(os.path.join(projection_dir, metadata_path),
                        index=False, header=write_header, sep='\t')
        embedding_elem.metadata_path = metadata_path

    summary_writer = tf.summary.FileWriter(projection_dir)
    try:
        projector.visualize_embeddings(summary_writer, config)
    finally:
        # Release the writer's event file instead of leaving it open.
        summary_writer.close()
###################
# Example (requires some .jpg images in the ./images directory)
##################
from glob import iglob

image_paths = list(iglob("./images/*.jpg"))

# Use pixel data (downscaled to 20x20) as embedding.
# Every image is converted to RGB first so that all rows have the same length
# (20 * 20 * 3); grayscale or RGBA files would otherwise yield rows of a
# different size and np.vstack would fail.
embedding = np.vstack([
    np.asarray(Image.open(p).convert('RGB').resize((20, 20)), dtype=np.uint8).reshape(-1,)
    for p in image_paths
])

# Use image paths and names for metadata.
# The label is the last path component, split on either '/' or '\'.
metadata = pd.DataFrame({'path': image_paths})
metadata['label'] = metadata['path'].str.split('[/\\\\]').str[-1]

# Create projection tensorboard log
prepare_projection(embedding, metadata, image_paths, sprite_w=128, sprite_h=128)

# Done! (start with `tensorboard --logdir ./projector_log`)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment