inception_heatmap.py
Created September 3, 2016
from scipy import misc
from scipy.ndimage import filters
import numpy as np
import tensorflow as tf
#from inception import image_processing
from inception import inception_model as inception
#from PIL import Image
#from PIL import ImageFile
import os

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('image_file', '',
                           """Image to segment.""")
tf.app.flags.DEFINE_string('eval_dir', '/tmp/imagenet_eval',
                           """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/imagenet_train',
                           """Directory where to read model checkpoints.""")
# def _eval_once(saver, summary_writer, top_1_op, top_5_op, summary_op):
#   """Runs Eval once.
#
#   Args:
#     saver: Saver.
#     summary_writer: Summary writer.
#     top_1_op: Top 1 op.
#     top_5_op: Top 5 op.
#     summary_op: Summary op.
#   """
#   with tf.Session() as sess:
#     ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
#     if ckpt and ckpt.model_checkpoint_path:
#       if os.path.isabs(ckpt.model_checkpoint_path):
#         # Restores from checkpoint with absolute path.
#         saver.restore(sess, ckpt.model_checkpoint_path)
#       else:
#         # Restores from checkpoint with relative path.
#         saver.restore(sess, os.path.join(FLAGS.checkpoint_dir,
#                                          ckpt.model_checkpoint_path))
#
#       # Assuming model_checkpoint_path looks something like:
#       #   /my-favorite-path/imagenet_train/model.ckpt-0,
#       # extract global_step from it.
#       global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
#       print('Successfully loaded model from %s at step=%s.' %
#             (ckpt.model_checkpoint_path, global_step))
#     else:
#       print('No checkpoint file found')
#       return
#
#     # Start the queue runners.
#     coord = tf.train.Coordinator()
#     try:
#       threads = []
#       for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
#         threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
#                                          start=True))
#
#       num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
#       # Counts the number of correct predictions.
#       count_top_1 = 0.0
#       count_top_5 = 0.0
#       total_sample_count = num_iter * FLAGS.batch_size
#       step = 0
#
#       print('%s: starting evaluation on (%s).' % (datetime.now(), FLAGS.subset))
#       start_time = time.time()
#       while step < num_iter and not coord.should_stop():
#         top_1, top_5 = sess.run([top_1_op, top_5_op])
#         count_top_1 += np.sum(top_1)
#         count_top_5 += np.sum(top_5)
#         step += 1
#         if step % 20 == 0:
#           duration = time.time() - start_time
#           sec_per_batch = duration / 20.0
#           examples_per_sec = FLAGS.batch_size / sec_per_batch
#           print('%s: [%d batches out of %d] (%.1f examples/sec; %.3f '
#                 'sec/batch)' % (datetime.now(), step, num_iter,
#                                 examples_per_sec, sec_per_batch))
#           start_time = time.time()
#
#       # Compute precision @ 1.
#       precision_at_1 = count_top_1 / total_sample_count
#       recall_at_5 = count_top_5 / total_sample_count
#       print('%s: precision @ 1 = %.4f recall @ 5 = %.4f [%d examples]' %
#             (datetime.now(), precision_at_1, recall_at_5, total_sample_count))
#
#       summary = tf.Summary()
#       summary.ParseFromString(sess.run(summary_op))
#       summary.value.add(tag='Precision @ 1', simple_value=precision_at_1)
#       summary.value.add(tag='Recall @ 5', simple_value=recall_at_5)
#       summary_writer.add_summary(summary, global_step)
#
#     except Exception as e:  # pylint: disable=broad-except
#       coord.request_stop(e)
#
#     coord.request_stop()
#     coord.join(threads, stop_grace_period_secs=10)
def evaluate():
  #ImageFile.LOAD_TRUNCATED_IMAGES = True
  resize_factor = 3

  # Load the image and upsample it, so that the fixed 299x299 window
  # covers a proportionally smaller patch of the scene.
  #im = Image.open(FLAGS.image_file)
  im = misc.imread(FLAGS.image_file)
  im = misc.imresize(im, float(resize_factor), interp='bilinear')
  print(FLAGS.image_file)

  w = im.shape[0]  # rows of the resized image
  h = im.shape[1]  # columns of the resized image
  a = im  # np.array(im.getdata()).reshape(im.size[0], im.size[1], 3)
  print(a.shape)
  #a = np.pad(a,((150,150),(150,150),(0,0)),mode='constant',constant_values=[0])

  num_classes = 1001

  # Slide a 299x299 window across the image, run each crop through the
  # Inception model restored from FLAGS.checkpoint_dir, and collect the
  # per-class logits into a spatial heatmap.
  with tf.Graph().as_default():
    images = tf.placeholder(dtype=tf.float32, shape=(1, 299, 299, 3))
    logits, _ = inception.inference(images, num_classes)

    # Restore the moving average version of the learned variables for eval.
    with tf.Session() as sess:
      variable_averages = tf.train.ExponentialMovingAverage(
          inception.MOVING_AVERAGE_DECAY)
      variables_to_restore = variable_averages.variables_to_restore()
      saver = tf.train.Saver()  #variables_to_restore
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
        if os.path.isabs(ckpt.model_checkpoint_path):
          # Restores from checkpoint with absolute path.
          saver.restore(sess, ckpt.model_checkpoint_path)
        else:
          # Restores from checkpoint with relative path.
          saver.restore(sess, os.path.join(FLAGS.checkpoint_dir,
                                           ckpt.model_checkpoint_path))
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' %
              (ckpt.model_checkpoint_path, global_step))
      else:
        print('No checkpoint file found')
        return

      # Stride of the sliding window, in pre-resize source pixels.
      skip = 4
      hm_w = int(np.ceil((w - 300) / float(skip * resize_factor)))
      hm_h = int(np.ceil((h - 300) / float(skip * resize_factor)))
      print((hm_w, hm_h))
      htmap = np.zeros((hm_w, hm_h, num_classes), dtype=np.float32)
      for i in range(0, hm_w):
        for j in range(0, hm_h):
          i2 = i * skip * resize_factor
          j2 = j * skip * resize_factor
          print((i2, j2))
          crp = a[i2:i2+299, j2:j2+299, :].reshape(1, 299, 299, 3)
          # Map pixel values from [0, 255] to [-1, 1], the range the
          # Inception model expects.
          crp = (crp.astype('float32') / 128.0) - 1
          #logits, _ = inception.inference(imgs, num_classes)
          htmap[i, j, :] = sess.run(logits, feed_dict={images: crp})
          #print crp[0:4,0:4,:]
          #print np.argmax(l)

      # Smooth each class channel of the heatmap spatially (sigma = 32/skip).
      for i in range(0, num_classes):
        htmap[:, :, i] = filters.gaussian_filter(htmap[:, :, i], 32 / skip, 0)

      # Print the highest-scoring class at each heatmap cell.
      for i in range(0, hm_w):
        for j in range(0, hm_h):
          print(np.argmax(htmap[i, j, :]))
#
#    # Restore the moving average version of the learned variables for eval.
#    variable_averages = tf.train.ExponentialMovingAverage(
#        inception.MOVING_AVERAGE_DECAY)
#    variables_to_restore = variable_averages.variables_to_restore()
#    saver = tf.train.Saver(variables_to_restore)
#
#    # Build the summary operation based on the TF collection of Summaries.
#    summary_op = tf.merge_all_summaries()
#
#    graph_def = tf.get_default_graph().as_graph_def()
#    summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
#                                            graph_def=graph_def)
#
#    while True:
#      _eval_once(saver, summary_writer, top_1_op, top_5_op, summary_op)
#      if FLAGS.run_once:
#        break
#      time.sleep(FLAGS.eval_interval_secs)
def main(unused_argv=None):
  evaluate()


if __name__ == '__main__':
  tf.app.run()
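The script only prints the winning class index at each heatmap cell. As a minimal sketch of how one might render a single class channel of htmap as an image, assuming matplotlib is available (the save_heatmap helper, the class_id parameter, and the output path are hypothetical, not part of the original gist):

import matplotlib
matplotlib.use('Agg')  # assumption: render without a display
import matplotlib.pyplot as plt

def save_heatmap(htmap, class_id, out_path='heatmap.png'):
    # Hypothetical helper: draws one class channel of the smoothed
    # heatmap and writes it to disk.
    plt.figure()
    plt.imshow(htmap[:, :, class_id], cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.savefig(out_path)
    plt.close()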