Last active
May 19, 2017 21:52
-
-
Save kukuruza/b107575fd7418be764b106953131d955 to your computer and use it in GitHub Desktop.
Example of generating embeddings for images from a saved model (TensorFlow)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# imports and other stuff
# NOTE(review): this snippet relies on `tf` (TensorFlow 1.x API: ConfigProto,
# Session, train.Saver, summary.FileWriter), `np`, `os` and `projector` having
# been imported here; those import lines are not part of the snippet.

# Example constants and instances of classes; replace with your own.
data = Data()  # some data provider; used below via size() and next()
model = Model()  # some model; used below via its `images` and `features` tensors
batch_size = 16
pretrained_model_path = 'model/my_model'
log_dir = 'logs'
# Floor division keeps this an int on both Python 2 and 3. Plain `/` gives a
# float on Python 3, which breaks range() and the np.zeros shape below.
# Any trailing partial batch (data.size() % batch_size items) is skipped.
num_batches = data.size() // batch_size
tfconfig = tf.ConfigProto()

# Width of one embedding row.
# NOTE(review): assumed to equal the last dimension of model.features -- confirm.
embedding_dim = 128

# A numpy array for embeddings (one row per image) and a list for labels.
features = np.zeros(shape=(num_batches * batch_size, embedding_dim), dtype=float)
labels = []
# Compute embeddings batch by batch: restore the pretrained weights, then feed
# each batch of images through the model. Feature rows go into the matching
# slice of `features`; labels are appended to `labels` in the same order.
with tf.Session(config=tfconfig) as sess:
    # Initialize all variables first, then overwrite them from the checkpoint.
    tf.global_variables_initializer().run()
    tf.train.Saver().restore(sess, pretrained_model_path)
    for batch_index in range(num_batches):
        images_batch, labels_batch = data.next()
        labels.extend(labels_batch)
        row_start = batch_index * batch_size
        row_stop = row_start + batch_size
        batch_features = sess.run(model.features,
                                  feed_dict={model.images: images_batch})
        features[row_start:row_stop, :] = batch_features
# Write labels to <log_dir>/metadata.tsv, one per line, in the same order as
# the rows of `features`. This path is later registered as the embedding's
# metadata file.
if not os.path.isdir(log_dir):
    # open() does not create missing directories, and nothing earlier in the
    # script is guaranteed to have created log_dir.
    os.makedirs(log_dir)
metadata_path = os.path.join(log_dir, 'metadata.tsv')
with open(metadata_path, 'w') as f:
    for label in labels:
        # Wrap in a 1-tuple so a tuple-valued label is formatted as one value
        # instead of being unpacked by the % operator.
        f.write('%s\n' % (label,))
# Write embeddings: store `features` in a TF variable, save it to a checkpoint
# under log_dir, and write a projector config that points at that variable
# and at the labels file.
# The session uses `tfconfig`: the original passed `config`, which was not yet
# defined at that point and raised NameError.
with tf.Session(config=tfconfig) as sess:
    # Named projector_config to avoid confusion with the session config.
    projector_config = projector.ProjectorConfig()
    embedding = projector_config.embeddings.add()
    # Must match the variable name used below and the key in the Saver dict.
    embedding.tensor_name = 'feature_embedding'
    embedding.metadata_path = metadata_path
    embedding_var = tf.Variable(features, name='feature_embedding')
    sess.run(embedding_var.initializer)
    summary_writer = tf.summary.FileWriter(log_dir)
    try:
        projector.visualize_embeddings(summary_writer, projector_config)
    finally:
        # The original never closed the writer, leaking its file handle.
        summary_writer.close()
    saver = tf.train.Saver({"feature_embedding": embedding_var})
    saver.save(sess, os.path.join(log_dir, 'model_features'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment