# GitHub Gist by @warmspringwinds, created February 8, 2017 18:18.
# Gist ID: 954bfec06de518e10ddefd049470c5e5 (it can be saved locally or opened in GitHub Desktop).
%matplotlib inline
import tensorflow as tf
import numpy as np
import skimage.io as io
import os, sys
from PIL import Image
from matplotlib import pyplot as plt
# Make the project package and the TF-slim model definitions importable.
# These entries must be registered before the tf_image_segmentation imports
# further down the file.
# NOTE(review): machine-specific paths — adjust for your environment.
sys.path.extend(["tf-image-segmentation/",
                 "/home/dpakhom1/workspace/my_models/slim/"])

# Checkpoint locations. Only the ResNet-101 8s checkpoint is restored by this
# script; checkpoints_dir is not referenced in the visible code.
checkpoints_dir = '/home/dpakhom1/checkpoints'
resnet_101_v1_checkpoint_path = (
    '/home/dpakhom1/tf_projects/segmentation/model_resnet_101_8s.ckpt')
# Alternative (unused) FCN-8s checkpoint:
#   '/home/dpakhom1/tf_projects/segmentation/model_fcn8s_final.ckpt'

# Expose only CUDA device 2 to TensorFlow; set before any session is created.
os.environ.update({"CUDA_VISIBLE_DEVICES": '2'})
# Short alias for TF-slim, used below for the mean-IoU metric and for looking up
# variables by name. NOTE: tf.contrib was removed in TensorFlow 2.x.
slim = tf.contrib.slim
from tf_image_segmentation.models.fcn_8s import FCN_8s
from tf_image_segmentation.models.resnet_v1_101_8s import resnet_v1_101_8s, extract_resnet_v1_101_mapping_without_logits
from matplotlib import pyplot as plt
from tf_image_segmentation.utils.pascal_voc import pascal_segmentation_lut
from tf_image_segmentation.utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors
from tf_image_segmentation.utils.inference import adapt_network_for_any_size_input
from tf_image_segmentation.utils.visualization import visualize_segmentation_adaptive
# Build the ResNet-101 8s segmentation graph on the Pascal VOC validation
# TFRecords, restore its checkpoint, and fetch the current values of the
# 'moving_mean' variables into `result`.

# Pascal VOC segmentation label look-up table (not used further below).
pascal_voc_lut = pascal_segmentation_lut()
tfrecord_filename = 'pascal_augmented_val.tfrecords'
number_of_classes = 21

# Single pass over the validation records. num_epochs creates a local epoch
# counter, which is why the local-variable initializer is run in the session.
filename_queue = tf.train.string_input_producer(
    [tfrecord_filename], num_epochs=1)

image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(
    filename_queue)

# Fake a batch of size one for image and annotation by adding a leading axis.
image_batch_tensor = tf.expand_dims(image, axis=0)
annotation_batch_tensor = tf.expand_dims(annotation, axis=0)

# Be careful: after adaptation, the network returns final labels and not
# logits.
resnet_v1_101_8s = adapt_network_for_any_size_input(resnet_v1_101_8s, 8)
pred, fcn_8s_variables_mapping = resnet_v1_101_8s(
    image_batch_tensor=image_batch_tensor,
    number_of_classes=number_of_classes,
    is_training=False)

# Exclude pixels labelled 255 from evaluation by giving them weight 0.
weights = tf.to_float(tf.not_equal(annotation_batch_tensor, 255))

# Accuracy metric: streaming mean Intersection-over-Union.
# NOTE(review): miou / update_op are built but never run in this script.
miou, update_op = slim.metrics.streaming_mean_iou(
    predictions=pred,
    labels=annotation_batch_tensor,
    num_classes=number_of_classes,
    weights=weights)

# Only local variables (epoch counter, metric accumulators) are initialised;
# the model's global variables are restored from the checkpoint instead.
initializer = tf.local_variables_initializer()
saver = tf.train.Saver()

# Variables named 'moving_mean' / 'moving_variance' (presumably the
# batch-norm statistics -- confirm against the model definition).
means = slim.get_variables_by_name('moving_mean')
variances = slim.get_variables_by_name('moving_variance')

with tf.Session() as sess:
    sess.run(initializer)
    saver.restore(sess, resnet_101_v1_checkpoint_path)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        # One array per variable in `means`.
        result = sess.run(means)
    finally:
        # Always shut down the input-pipeline threads, even if the fetch fails.
        coord.request_stop()
        coord.join(threads)
# End of gist. To comment on it, sign in to GitHub (or sign up for free).