Created December 6, 2017 04:55
-
-
Save yaboo-oyabu/d92acc9f9263c92609d6c8d94e42c1d8 to your computer and use it in GitHub Desktop.
Vectorize image with VGG-16
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import os | |
import sys | |
import tensorflow as tf | |
import json | |
import tensorflow.contrib.slim.python.slim.nets.vgg as vgg | |
from tensorflow.python.platform import gfile | |
# Command-line flags consumed by main(); both default to None.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('output_json', None, 'specify output json file path')
flags.DEFINE_string('jpeg_pattern', None, 'specify target jpeg file pattern')

# Request the pure-Python protobuf implementation.
# NOTE(review): protobuf appears to choose its implementation when it is first
# imported, which has already happened via ``import tensorflow`` above, so this
# assignment may have no effect at this point -- confirm; to take effect it
# would have to run before tensorflow is imported.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Alias for TF-Slim (not referenced by the code shown here).
slim = tf.contrib.slim
def main():
    """Embed every JPEG matching ``--jpeg_pattern`` with VGG-16 as JSON lines.

    Loads the frozen graph ``vgg16.tfmodel`` located next to this script,
    decodes and resizes each matched JPEG, and runs the images through the
    network in batches. For each image it writes one JSON object per line to
    ``--output_json``::

        {"key": <file basename without extension>,
         "vector": <'vgg16/fc8:0' output row as a list>,
         "prob": <'vgg16/prob:0' output row as a list>}

    Raises:
        ValueError: if ``--output_json`` or ``--jpeg_pattern`` is not set.
        IOError: if ``vgg16.tfmodel`` does not exist next to this script.
    """
    output_json = FLAGS.output_json
    target_jpeg = FLAGS.jpeg_pattern
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info('TensorFlow version = {}'.format(tf.__version__))
    tf.logging.info('output_json={}'.format(output_json))
    tf.logging.info('target_jpeg={}'.format(target_jpeg))
    # Fail early with a clear message instead of an opaque TypeError from
    # open(None) / gfile.Glob(None) further down.
    if not output_json:
        raise ValueError('--output_json is required')
    if not target_jpeg:
        raise ValueError('--jpeg_pattern is required')

    # Model source: https://s3.amazonaws.com/cadl/models/vgg16.tfmodel
    model_file = os.path.join(os.path.dirname(__file__), 'vgg16.tfmodel')
    if not os.path.exists(model_file):
        # A bare ``assert`` is stripped under ``python -O`` and carries no
        # message; raise an explicit error naming the expected path instead.
        raise IOError('VGG-16 model file not found: {}'.format(model_file))
    graph_def = tf.GraphDef()
    with open(model_file, mode='rb') as f:
        graph_def.ParseFromString(f.read())

    image_size = vgg.vgg_16.default_image_size
    tf.import_graph_def(graph_def, name='vgg16')
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name('vgg16/images:0')
    row = graph.get_tensor_by_name('vgg16/fc8:0')
    prob = graph.get_tensor_by_name('vgg16/prob:0')

    # Per-file decode pipeline: path -> 3-channel image resized to
    # image_size x image_size, with pixel values divided by 255.
    jpeg = tf.placeholder(tf.string, shape=(), name="JPEG")
    image = tf.image.decode_jpeg(tf.read_file(jpeg), channels=3)
    image = tf.image.resize_images(image, (image_size, image_size))
    image = tf.cast(image, tf.float64) / 255.0

    batch_size = 100
    with open(output_json, "w") as jfile:
        with tf.Session() as sess:
            files = gfile.Glob(target_jpeg)
            for start in range(0, len(files), batch_size):
                fs = files[start:start + batch_size]
                images = [sess.run(image, feed_dict={jpeg: f}) for f in fs]
                # Override the dropout layers' random_uniform tensors with
                # ones so their masks are deterministic (presumably keeping
                # every unit -- confirm against the graph). The mask must
                # match the number of images actually fed: the final batch is
                # usually smaller than batch_size, and a fixed
                # (batch_size, 4096) mask would not match it.
                # Assumes the dropout inputs are (batch, 4096) -- TODO confirm.
                keep_all = np.ones(shape=(len(fs), 4096))
                v1, v2 = sess.run([row, prob], feed_dict={
                    x: images,
                    'vgg16/dropout_1/random_uniform:0': keep_all,
                    'vgg16/dropout/random_uniform:0': keep_all})
                for path, vector, p in zip(fs, v1, v2):
                    jobj = {
                        "key": os.path.splitext(os.path.basename(path))[0],
                        "vector": vector.tolist(),
                        "prob": p.tolist()
                    }
                    jfile.write(json.dumps(jobj) + "\n")
                # Persist each completed batch so partial results survive
                # an interruption.
                jfile.flush()
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.