Skip to content

Instantly share code, notes, and snippets.

@yaboo-oyabu
Created December 6, 2017 04:55
Show Gist options
  • Save yaboo-oyabu/d92acc9f9263c92609d6c8d94e42c1d8 to your computer and use it in GitHub Desktop.
Save yaboo-oyabu/d92acc9f9263c92609d6c8d94e42c1d8 to your computer and use it in GitHub Desktop.
Vectorize image with VGG-16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
import tensorflow as tf
import json
import tensorflow.contrib.slim.python.slim.nets.vgg as vgg
from tensorflow.python.platform import gfile
# --- Command-line flags -------------------------------------------------------
flags = tf.app.flags
flags.DEFINE_string('output_json', None, 'specify output json file path')
flags.DEFINE_string('jpeg_pattern', None, 'specify target jpeg file pattern')
FLAGS = flags.FLAGS  # `flags` is tf.app.flags, so this is tf.app.flags.FLAGS

# Ask protobuf for its pure-Python implementation.
# NOTE(review): this assignment runs after `import tensorflow` above, which
# presumably has already imported protobuf. If protobuf selects its
# implementation when first imported, this line has no effect. Confirm, and
# move it above the tensorflow import if the pure-Python backend is required.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Convenience alias; not referenced by the visible code in this file.
slim = tf.contrib.slim
def main():
    """Embed every JPEG matching --jpeg_pattern with a pretrained VGG-16 graph.

    Loads the serialized GraphDef ``vgg16.tfmodel`` from this script's
    directory (https://s3.amazonaws.com/cadl/models/vgg16.tfmodel) and imports
    it under the name scope ``vgg16``. Each matching file is then decoded,
    resized to ``vgg.vgg_16.default_image_size`` and scaled by 1/255. The
    images are run through the network in batches of 100. For every image,
    one JSON object is written per line to --output_json::

        {"key": <file basename without extension>,
         "vector": <output of tensor vgg16/fc8:0>,
         "prob": <output of tensor vgg16/prob:0>}

    Raises:
      ValueError: if --output_json or --jpeg_pattern is not set.
      IOError: if ``vgg16.tfmodel`` is not present next to this script.
    """
    output_json = FLAGS.output_json
    target_jpeg = FLAGS.jpeg_pattern
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('TensorFlow version = {}'.format(tf.__version__))
    tf.logging.info('output_json={}'.format(output_json))
    tf.logging.info('target_jpeg={}'.format(target_jpeg))
    # Fail fast, before the (large) model is read, instead of crashing later
    # inside open(None) / gfile.Glob(None).
    if not output_json or not target_jpeg:
        raise ValueError('--output_json and --jpeg_pattern must both be set')

    # https://s3.amazonaws.com/cadl/models/vgg16.tfmodel
    model_file = os.path.join(os.path.dirname(__file__), 'vgg16.tfmodel')
    if not os.path.exists(model_file):
        # Explicit error instead of `assert False`, which vanishes under -O.
        raise IOError(
            'VGG-16 model not found at {} (download it from '
            'https://s3.amazonaws.com/cadl/models/vgg16.tfmodel)'.format(
                model_file))
    graph_def = tf.GraphDef()
    with open(model_file, mode='rb') as f:
        graph_def.ParseFromString(f.read())

    image_size = vgg.vgg_16.default_image_size
    tf.import_graph_def(graph_def, name='vgg16')
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name('vgg16/images:0')
    row = graph.get_tensor_by_name('vgg16/fc8:0')
    prob = graph.get_tensor_by_name('vgg16/prob:0')

    # Single-image decode pipeline, fed one file path at a time.
    jpeg = tf.placeholder(tf.string, shape=(), name="JPEG")
    image = tf.image.decode_jpeg(tf.read_file(jpeg), channels=3)
    image = tf.image.resize_images(image, (image_size, image_size))
    # Scale 0..255 pixel values down to 0..1.
    # NOTE(review): assumes the imported graph expects [0, 1] inputs -- confirm
    # against the preprocessing the model was trained with.
    image = tf.cast(image, tf.float64) / 255.0

    group_by = 100
    with open(output_json, "w") as jfile:
        with tf.Session() as sess:
            files = gfile.Glob(target_jpeg)
            for start in range(0, len(files), group_by):
                fs = files[start:start + group_by]
                images = [sess.run(image, feed_dict={jpeg: f}) for f in fs]
                # Override the graph's dropout random_uniform nodes with ones,
                # presumably to neutralize dropout at inference time. The fed
                # array must have this batch's size: the last batch is usually
                # smaller than group_by, and a (group_by, 4096) feed would then
                # disagree with the (len(fs), 4096) activations.
                dropout_ones = np.ones(shape=(len(fs), 4096))
                v1, v2 = sess.run([row, prob], feed_dict={
                    x: images,
                    'vgg16/dropout_1/random_uniform:0': dropout_ones,
                    'vgg16/dropout/random_uniform:0': dropout_ones})
                for path, vector, probs in zip(fs, v1, v2):
                    jobj = {
                        "key": os.path.splitext(os.path.basename(path))[0],
                        "vector": vector.tolist(),
                        "prob": probs.tolist()
                    }
                    jfile.write(json.dumps(jobj) + "\n")
                # Make each completed batch durable before starting the next.
                jfile.flush()
if __name__ == "__main__":
    # Run the vectorizer only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment