Example of generating embeddings for images from a saved model (TensorFlow)
# imports
import os
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

# example constants and instances of classes
data = Data()    # some data provider
model = Model()  # some model
batch_size = 16
pretrained_model_path = 'model/my_model'
log_dir = 'logs'
num_batches = data.size() // batch_size  # integer number of full batches
tfconfig = tf.ConfigProto()
# a numpy array for embeddings and a list for labels
features = np.zeros(shape=(num_batches*batch_size, 128), dtype=float)
labels = []
# compute embeddings batch by batch
with tf.Session(config=tfconfig) as sess:
    tf.global_variables_initializer().run()
    restorer = tf.train.Saver()
    restorer.restore(sess, pretrained_model_path)
    for step in range(num_batches):
        batch_images, batch_labels = data.next()
        labels += batch_labels
        feed_dict = {model.images: batch_images}
        features[step*batch_size : (step+1)*batch_size, :] = \
            sess.run(model.features, feed_dict)
# write labels
metadata_path = os.path.join(log_dir, 'metadata.tsv')
with open(metadata_path, 'w') as f:
    for label in labels:
        f.write('%s\n' % label)
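# Note: with a single metadata column like this, the projector expects no
# header row; only multi-column metadata.tsv files need a tab-separated
# header line naming the columns.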
# write embeddings
with tf.Session(config=tfconfig) as sess:
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = 'feature_embedding'
    embedding.metadata_path = metadata_path
    embedding_var = tf.Variable(features, name='feature_embedding')
    sess.run(embedding_var.initializer)
    projector.visualize_embeddings(tf.summary.FileWriter(log_dir), config)
    saver = tf.train.Saver({"feature_embedding": embedding_var})
    saver.save(sess, os.path.join(log_dir, 'model_features'))
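Data and Model above are placeholders for the reader's own input pipeline and network. As a rough, hypothetical sketch of the minimal interface the snippet assumes (class names, shapes, and layers here are illustrative assumptions, not part of the original):

import tensorflow as tf

class Data:
    """Hypothetical data provider: yields fixed-size batches of images and labels."""
    def __init__(self, images, labels, batch_size=16):
        self.images, self.labels = images, labels
        self.batch_size = batch_size
        self._pos = 0
    def size(self):
        return len(self.labels)
    def next(self):
        i = self._pos
        self._pos += self.batch_size
        return (self.images[i:i + self.batch_size],
                list(self.labels[i:i + self.batch_size]))

class Model:
    """Hypothetical model: an image placeholder and a 128-d feature tensor."""
    def __init__(self):
        self.images = tf.placeholder(tf.float32, [None, 224, 224, 3])
        pooled = tf.reduce_mean(self.images, axis=[1, 2])  # toy pooling stand-in for a real network
        self.features = tf.layers.dense(pooled, 128)       # 128-d embedding per image

Once log_dir contains the checkpoint, projector_config.pbtxt, and metadata.tsv, the embeddings can be explored by running tensorboard --logdir logs and opening the Projector tab.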