Created February 8, 2017 at 18:18.
Save warmspringwinds/954bfec06de518e10ddefd049470c5e5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: this originated as a Jupyter notebook cell; the "%matplotlib inline"
# magic is kept as a comment so the file is valid plain Python.
# %matplotlib inline

import os
import sys

import numpy as np
import skimage.io as io
import tensorflow as tf
from PIL import Image
from matplotlib import pyplot as plt

# Make the project-local segmentation package and the TF-slim model zoo
# importable (must happen before the tf_image_segmentation imports below).
sys.path.append("tf-image-segmentation/")
sys.path.append("/home/dpakhom1/workspace/my_models/slim/")

checkpoints_dir = '/home/dpakhom1/checkpoints'
#fcn_16s_checkpoint_path = '/home/dpakhom1/tf_projects/segmentation/model_fcn8s_final.ckpt'
resnet_101_v1_checkpoint_path = '/home/dpakhom1/tf_projects/segmentation/model_resnet_101_8s.ckpt'

# Pin this process to a single GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

slim = tf.contrib.slim

from tf_image_segmentation.models.fcn_8s import FCN_8s
from tf_image_segmentation.models.resnet_v1_101_8s import resnet_v1_101_8s, extract_resnet_v1_101_mapping_without_logits
from tf_image_segmentation.utils.pascal_voc import pascal_segmentation_lut
from tf_image_segmentation.utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors
from tf_image_segmentation.utils.inference import adapt_network_for_any_size_input
from tf_image_segmentation.utils.visualization import visualize_segmentation_adaptive

pascal_voc_lut = pascal_segmentation_lut()
tfrecord_filename = 'pascal_augmented_val.tfrecords'
number_of_classes = 21  # presumably 20 PASCAL VOC classes + background — confirm against the LUT
# Input pipeline: stream (image, annotation) pairs from the validation
# TFRecord file. num_epochs=1 means a single pass over the validation set
# (the epoch counter is a *local* variable — see the initializer note below).
filename_queue = tf.train.string_input_producer(
    [tfrecord_filename], num_epochs=1)

image, annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue)

# Fake a batch of size 1 for image and annotation by adding a leading axis.
image_batch_tensor = tf.expand_dims(image, axis=0)
annotation_batch_tensor = tf.expand_dims(annotation, axis=0)

# Be careful: after adaptation the network returns final labels, not logits.
resnet_v1_101_8s = adapt_network_for_any_size_input(resnet_v1_101_8s, 8)
pred, resnet_v1_101_8s_variables_mapping = resnet_v1_101_8s(image_batch_tensor=image_batch_tensor,
                                                            number_of_classes=number_of_classes,
                                                            is_training=False)

# Take the masked-out pixels (label 255 — the PASCAL VOC "ambiguous" border
# class) away from the evaluation by giving them zero weight.
weights = tf.to_float(tf.not_equal(annotation_batch_tensor, 255))

# Accuracy metric: streaming mean Intersection-over-Union. `update_op`
# accumulates the confusion matrix; `miou` reads the current value.
miou, update_op = slim.metrics.streaming_mean_iou(predictions=pred,
                                                  labels=annotation_batch_tensor,
                                                  num_classes=number_of_classes,
                                                  weights=weights)

# Only local variables need explicit initialization here: the metric
# accumulators and the queue's epoch counter are local variables, while all
# model (global) variables are restored from the checkpoint in the session.
initializer = tf.local_variables_initializer()

saver = tf.train.Saver()

# Batch-norm statistics, collected for inspection in the session below.
means = slim.get_variables_by_name('moving_mean')
variances = slim.get_variables_by_name('moving_variance')
with tf.Session() as sess:
    # Initialize local variables (metric accumulators, epoch counter), then
    # restore all model weights from the checkpoint.
    sess.run(initializer)
    saver.restore(sess, resnet_101_v1_checkpoint_path)

    # The input queue is fed by background threads; run them under a
    # coordinator so they can be shut down cleanly.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        # Fetch the batch-norm moving means for inspection.
        # NOTE(review): the mIoU `update_op` defined above is never run in
        # this session, so the metric is never accumulated — presumably this
        # cell only inspects batch-norm statistics; confirm intent.
        result = sess.run(means)
    finally:
        # Always stop and join the queue-runner threads, even when the
        # sess.run above raises, so no background threads are leaked.
        coord.request_stop()
        coord.join(threads)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.