Last active
July 1, 2020 18:08
-
-
Save inyl/5679509f1211a80d0abea149338d49c8 to your computer and use it in GitHub Desktop.
TensorFlow models: im2txt (Show and Tell image captioning), Python 3 version
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Bazel targets for the im2txt ("Show and Tell") image-captioning model.
# Binaries are pinned to Python 3; libraries declare PY2AND3 sources.
# NOTE(review): newer Bazel releases deprecate `default_python_version` in
# favour of `python_version` — confirm against the Bazel version in use.

package(default_visibility = [":internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Default visibility group: every package under //im2txt/.
package_group(
    name = "internal",
    packages = ["//im2txt/..."],
)

# ---------------------------------------------------------------------------
# Data preparation.
# ---------------------------------------------------------------------------

py_binary(
    name = "build_mscoco_data",
    srcs = ["data/build_mscoco_data.py"],
    default_python_version = "PY3",
    srcs_version = "PY3",
)

sh_binary(
    name = "download_and_preprocess_mscoco",
    srcs = ["data/download_and_preprocess_mscoco.sh"],
    data = [":build_mscoco_data"],
)

# ---------------------------------------------------------------------------
# Model libraries.
# ---------------------------------------------------------------------------

py_library(
    name = "configuration",
    srcs = ["configuration.py"],
    srcs_version = "PY2AND3",
)

py_library(
    name = "show_and_tell_model",
    srcs = ["show_and_tell_model.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//im2txt/ops:image_embedding",
        "//im2txt/ops:image_processing",
        "//im2txt/ops:inputs",
    ],
)

py_test(
    name = "show_and_tell_model_test",
    size = "large",
    srcs = ["show_and_tell_model_test.py"],
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

py_library(
    name = "inference_wrapper",
    srcs = ["inference_wrapper.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":show_and_tell_model",
        "//im2txt/inference_utils:inference_wrapper_base",
    ],
)

# ---------------------------------------------------------------------------
# Entry points: training, evaluation and caption generation.
# ---------------------------------------------------------------------------

py_binary(
    name = "train",
    srcs = ["train.py"],
    default_python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

py_binary(
    name = "evaluate",
    srcs = ["evaluate.py"],
    default_python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

py_binary(
    name = "run_inference",
    srcs = ["run_inference.py"],
    default_python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":configuration",
        ":inference_wrapper",
        "//im2txt/inference_utils:caption_generator",
        "//im2txt/inference_utils:vocabulary",
    ],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Converts MSCOCO data to TFRecord file format with SequenceExample protos. | |
The MSCOCO images are expected to reside in JPEG files located in the following | |
directory structure: | |
train_image_dir/COCO_train2014_000000000151.jpg | |
train_image_dir/COCO_train2014_000000000260.jpg | |
... | |
and | |
val_image_dir/COCO_val2014_000000000042.jpg | |
val_image_dir/COCO_val2014_000000000073.jpg | |
... | |
The MSCOCO annotations JSON files are expected to reside in train_captions_file | |
and val_captions_file respectively. | |
This script converts the combined MSCOCO data into sharded data files consisting | |
of 256, 4 and 8 TFRecord files, respectively: | |
output_dir/train-00000-of-00256 | |
output_dir/train-00001-of-00256 | |
... | |
output_dir/train-00255-of-00256 | |
and | |
output_dir/val-00000-of-00004 | |
... | |
output_dir/val-00003-of-00004 | |
and | |
output_dir/test-00000-of-00008 | |
... | |
output_dir/test-00007-of-00008 | |
Each TFRecord file contains ~2300 records. Each record within the TFRecord file | |
is a serialized SequenceExample proto consisting of precisely one image-caption | |
pair. Note that each image has multiple captions (usually 5) and therefore each | |
image is replicated multiple times in the TFRecord files. | |
The SequenceExample proto contains the following fields: | |
context: | |
image/image_id: integer MSCOCO image identifier | |
image/data: string containing JPEG encoded image in RGB colorspace | |
feature_lists: | |
image/caption: list of strings containing the (tokenized) caption words | |
image/caption_ids: list of integer ids corresponding to the caption words | |
The captions are tokenized using the NLTK (http://www.nltk.org/) word tokenizer. | |
The vocabulary of word identifiers is constructed from the sorted list (by | |
descending frequency) of word tokens in the training set. Only tokens appearing | |
at least 4 times are considered; all other words get the "unknown" word id. | |
NOTE: This script will consume around 100GB of disk space because each image | |
in the MSCOCO dataset is replicated ~5 times (once per caption) in the output. | |
This is done for two reasons: | |
1. In order to better shuffle the training data. | |
2. It makes it easier to perform asynchronous preprocessing of each image in | |
TensorFlow. | |
Running this script using 16 threads may take around 1 hour on a HP Z420. | |
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from collections import Counter | |
from collections import namedtuple | |
from datetime import datetime | |
import json | |
import os.path | |
import random | |
import sys | |
import threading | |
import nltk.tokenize | |
import numpy as np | |
import tensorflow as tf | |
# Command-line flags. Input/output locations default to paths under /tmp.
tf.flags.DEFINE_string("train_image_dir", "/tmp/train2014/",
                       "Training image directory.")
tf.flags.DEFINE_string("val_image_dir", "/tmp/val2014",
                       "Validation image directory.")

tf.flags.DEFINE_string("train_captions_file", "/tmp/captions_train2014.json",
                       "Training captions JSON file.")
tf.flags.DEFINE_string("val_captions_file", "/tmp/captions_val2014.json",
                       "Validation captions JSON file.")

tf.flags.DEFINE_string("output_dir", "/tmp/", "Output data directory.")

# Shard counts per split. main() checks each one against FLAGS.num_threads,
# and _process_image_files asserts that the thread count divides it.
tf.flags.DEFINE_integer("train_shards", 256,
                        "Number of shards in training TFRecord files.")
tf.flags.DEFINE_integer("val_shards", 4,
                        "Number of shards in validation TFRecord files.")
tf.flags.DEFINE_integer("test_shards", 8,
                        "Number of shards in testing TFRecord files.")

# Sentence markers added by _process_caption.
tf.flags.DEFINE_string("start_word", "<S>",
                       "Special word added to the beginning of each sentence.")
tf.flags.DEFINE_string("end_word", "</S>",
                       "Special word added to the end of each sentence.")
# NOTE(review): unknown_word is not referenced elsewhere in this script;
# _create_vocab maps unknown words to id len(vocab) and never writes this
# token to the word-counts file.
tf.flags.DEFINE_string("unknown_word", "<UNK>",
                       "Special word meaning 'unknown'.")
# Words seen fewer times than this in the training captions are excluded
# from the vocabulary by _create_vocab.
tf.flags.DEFINE_integer("min_word_count", 4,
                        "The minimum number of occurrences of each word in the "
                        "training set for inclusion in the vocabulary.")
tf.flags.DEFINE_string("word_counts_output_file", "/tmp/word_counts.txt",
                       "Output vocabulary file of word counts.")

# Upper bound on worker threads per split; _process_dataset actually uses
# min(num_shards, num_threads).
tf.flags.DEFINE_integer("num_threads", 8,
                        "Number of threads to preprocess the images.")

# Accessor for the flag values defined above.
FLAGS = tf.flags.FLAGS

# One image and its captions. _load_and_process_metadata fills `captions` with
# tokenized captions (each a list of strings); _process_dataset then splits
# each image into one ImageMetadata per caption (a single-element list).
ImageMetadata = namedtuple("ImageMetadata",
                           ["image_id", "filename", "captions"])
class Vocabulary(object):
  """Word-string to integer-id lookup with a fallback id for unknown words."""

  def __init__(self, vocab, unk_id):
    """Stores the lookup table.

    Args:
      vocab: A dictionary mapping word strings to integer ids.
      unk_id: Id returned for any word that is not a key of `vocab`.
    """
    self._vocab = vocab
    self._unk_id = unk_id

  def word_to_id(self, word):
    """Returns the id of `word`, or the unknown-word id when it is absent."""
    return self._vocab.get(word, self._unk_id)
class ImageDecoder(object):
  """Validates JPEG data by decoding it with one reusable TensorFlow session."""

  def __init__(self):
    # A single Session shared by every decode_jpeg() call.
    self._sess = tf.Session()

    # Decode graph, built once: string placeholder -> 3-channel JPEG decode.
    self._encoded_jpeg = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._encoded_jpeg, channels=3)

  def decode_jpeg(self, encoded_jpeg):
    """Decodes `encoded_jpeg` and returns the resulting image array.

    Raises:
      AssertionError: If the decoded result is not rank 3 with 3 channels.
      Errors raised by session.run on undecodable input propagate; the caller
      in this file catches tf.errors.InvalidArgumentError.
    """
    feed = {self._encoded_jpeg: encoded_jpeg}
    decoded = self._sess.run(self._decode_jpeg, feed_dict=feed)
    # Short-circuit keeps shape[2] from being indexed on lower-rank results.
    assert len(decoded.shape) == 3 and decoded.shape[2] == 3
    return decoded
def _int64_feature(value):
  """Wraps a single integer in an int64 Feature for a SequenceExample."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
  """Wraps a single bytes value in a bytes Feature for a SequenceExample."""
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)
def _int64_feature_list(values):
  """Wraps each integer in `values` as an int64 Feature inside a FeatureList."""
  features = list(map(_int64_feature, values))
  return tf.train.FeatureList(feature=features)
def _bytes_feature_list(values):
  """Wraps each str in `values` as a bytes Feature inside a FeatureList.

  Every value is converted with str.encode (UTF-8 by default) first, so the
  elements must be str instances.
  """
  features = []
  for text in values:
    features.append(_bytes_feature(str.encode(text)))
  return tf.train.FeatureList(feature=features)
def _to_sequence_example(image, decoder, vocab):
  """Builds a SequenceExample proto for a single image-caption pair.

  The encoded file bytes are stored verbatim; `decoder` is only used to check
  that they decode as a 3-channel JPEG.

  Args:
    image: An ImageMetadata object whose `captions` holds exactly one
      tokenized caption (a list of strings).
    decoder: An ImageDecoder object.
    vocab: A Vocabulary object.

  Returns:
    A SequenceExample proto, or None if the file does not contain valid JPEG
    data (a message is printed in that case).
  """
  # Binary mode: the JPEG bytes are decoded and stored in a BytesList as-is,
  # so they must never go through text decoding.
  with tf.gfile.FastGFile(image.filename, "rb") as f:
    encoded_image = f.read()

  try:
    decoder.decode_jpeg(encoded_image)
  except (tf.errors.InvalidArgumentError, AssertionError):
    print("Skipping file with invalid JPEG data: %s" % image.filename)
    return None

  context = tf.train.Features(feature={
      "image/image_id": _int64_feature(image.image_id),
      "image/data": _bytes_feature(encoded_image),
  })

  # _process_dataset splits images so that each record carries one caption.
  assert len(image.captions) == 1
  caption = image.captions[0]
  caption_ids = [vocab.word_to_id(word) for word in caption]
  feature_lists = tf.train.FeatureLists(feature_list={
      "image/caption": _bytes_feature_list(caption),
      "image/caption_ids": _int64_feature_list(caption_ids)
  })

  return tf.train.SequenceExample(context=context, feature_lists=feature_lists)
def _process_image_files(thread_index, ranges, name, images, decoder, vocab,
                         num_shards):
  """Processes and saves a subset of images as TFRecord files in one thread.

  This thread handles images[ranges[thread_index][0]:ranges[thread_index][1]]
  and spreads them over num_shards // len(ranges) consecutive output files
  named "<name>-<shard>-of-<num_shards>" inside FLAGS.output_dir. Images whose
  JPEG data fails validation are skipped.

  Args:
    thread_index: Integer thread identifier within [0, len(ranges)).
    ranges: A list of pairs of integers specifying the ranges of the dataset to
      process in parallel.
    name: Unique identifier specifying the dataset.
    images: List of ImageMetadata.
    decoder: An ImageDecoder object.
    vocab: A Vocabulary object.
    num_shards: Integer number of shards for the output files; must be a
      multiple of len(ranges).
  """
  # Each thread produces N shards where N = num_shards / num_threads. For
  # instance, if num_shards = 128, and num_threads = 2, then the first thread
  # would produce shards [0, 64).
  num_threads = len(ranges)
  assert not num_shards % num_threads
  num_shards_per_batch = num_shards // num_threads

  range_start, range_end = ranges[thread_index]
  shard_ranges = np.linspace(range_start, range_end,
                             num_shards_per_batch + 1).astype(int)
  num_images_in_thread = range_end - range_start

  counter = 0
  for s in range(num_shards_per_batch):
    # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
    shard = thread_index * num_shards_per_batch + s
    output_filename = "%s-%.5d-of-%.5d" % (name, shard, num_shards)
    output_file = os.path.join(FLAGS.output_dir, output_filename)

    shard_counter = 0
    writer = tf.python_io.TFRecordWriter(output_file)
    try:
      for i in range(shard_ranges[s], shard_ranges[s + 1]):
        sequence_example = _to_sequence_example(images[i], decoder, vocab)
        if sequence_example is None:
          continue  # Invalid JPEG; already reported by _to_sequence_example.

        writer.write(sequence_example.SerializeToString())
        shard_counter += 1
        counter += 1

        # Report only right after a successful write so that skipped images
        # cannot repeat the same progress line (e.g. at counter == 0).
        if not counter % 1000:
          print("%s [thread %d]: Processed %d of %d items in thread batch." %
                (datetime.now(), thread_index, counter, num_images_in_thread))
          sys.stdout.flush()
    finally:
      # Close the shard even if processing raised, so the file is not leaked.
      writer.close()

    print("%s [thread %d]: Wrote %d image-caption pairs to %s" %
          (datetime.now(), thread_index, shard_counter, output_file))
    sys.stdout.flush()

  print("%s [thread %d]: Wrote %d image-caption pairs to %d shards." %
        (datetime.now(), thread_index, counter, num_shards_per_batch))
  sys.stdout.flush()
def _process_dataset(name, images, vocab, num_shards):
  """Processes a complete data set and saves it as sharded TFRecord files.

  Each image is expanded into one record per caption, the records are shuffled
  with a fixed seed, and the work is split across
  min(num_shards, FLAGS.num_threads) threads running _process_image_files.

  Args:
    name: Unique identifier specifying the dataset.
    images: List of ImageMetadata.
    vocab: A Vocabulary object.
    num_shards: Integer number of shards for the output files.
  """
  # Break up each image into a separate entity for each caption.
  images = [ImageMetadata(image.image_id, image.filename, [caption])
            for image in images for caption in image.captions]

  # Shuffle the ordering of images. Make the randomization repeatable; a
  # private generator yields the same permutation as reseeding the global
  # `random` module, without clobbering global RNG state.
  random.Random(12345).shuffle(images)

  # Break the images into num_threads batches. Batch i is defined as
  # images[ranges[i][0]:ranges[i][1]].
  num_threads = min(num_shards, FLAGS.num_threads)
  # Builtin int: the np.int alias was removed in NumPy 1.24.
  spacing = np.linspace(0, len(images), num_threads + 1).astype(int)
  ranges = [[spacing[i], spacing[i + 1]] for i in range(len(spacing) - 1)]

  # Create a mechanism for monitoring when all threads are finished.
  coord = tf.train.Coordinator()

  # Create a utility for decoding JPEG images to run sanity checks.
  decoder = ImageDecoder()

  # Launch a thread for each batch.
  print("Launching %d threads for spacings: %s" % (num_threads, ranges))
  threads = []
  for thread_index in range(len(ranges)):
    args = (thread_index, ranges, name, images, decoder, vocab, num_shards)
    t = threading.Thread(target=_process_image_files, args=args)
    t.start()
    threads.append(t)

  # Wait for all the threads to terminate.
  coord.join(threads)
  print("%s: Finished processing all %d image-caption pairs in data set '%s'." %
        (datetime.now(), len(images), name))
def _create_vocab(captions):
  """Builds the word -> id vocabulary from tokenized training captions.

  Words seen fewer than FLAGS.min_word_count times are dropped. The rest are
  ordered by descending count (ties keep first-seen order) and written to
  FLAGS.word_counts_output_file as "<word> <count>" lines. A word's id is its
  0-based line number in that file; unknown words get id len(vocabulary).

  Args:
    captions: A list of lists of strings.

  Returns:
    A Vocabulary object.
  """
  print("Creating vocabulary.")
  counter = Counter(word for caption in captions for word in caption)
  print("Total words:", len(counter))

  # most_common() sorts by descending count and keeps first-seen order for
  # ties; filtering afterwards preserves that order.
  word_counts = [(word, count) for word, count in counter.most_common()
                 if count >= FLAGS.min_word_count]
  print("Words in vocabulary:", len(word_counts))

  # Write out the word counts file; line index == word id.
  with tf.gfile.FastGFile(FLAGS.word_counts_output_file, "w") as f:
    f.write("\n".join("%s %d" % pair for pair in word_counts))
  print("Wrote vocabulary file:", FLAGS.word_counts_output_file)

  # Ids follow file order; the unknown id is one past the last word.
  word_to_id = {word: index for index, (word, _) in enumerate(word_counts)}
  return Vocabulary(word_to_id, len(word_counts))
def _process_caption(caption):
  """Processes a caption string into a list of tokenized words.

  Args:
    caption: A string caption.

  Returns:
    A list of strings: FLAGS.start_word, the NLTK word tokens of the
    lower-cased caption, then FLAGS.end_word.
  """
  return [
      FLAGS.start_word,
      *nltk.tokenize.word_tokenize(caption.lower()),
      FLAGS.end_word,
  ]
def _load_and_process_metadata(captions_file, image_dir):
  """Loads image metadata from a JSON file and processes the captions.

  Args:
    captions_file: JSON file containing caption annotations, with top-level
      "images" (entries with "id" and "file_name") and "annotations" (entries
      with "image_id" and "caption") lists.
    image_dir: Directory containing the image files.

  Returns:
    A list of ImageMetadata, one per image, whose `captions` are tokenized.
  """
  # Read bytes and decode explicitly. In text mode, newer TensorFlow releases
  # return str from read(), on which str(..., encoding=...) raises TypeError;
  # binary mode returns bytes on every release.
  with tf.gfile.FastGFile(captions_file, "rb") as f:
    caption_data = json.loads(f.read().decode("utf-8"))

  # Extract the filenames.
  id_to_filename = [(x["id"], x["file_name"]) for x in caption_data["images"]]

  # Extract the captions. Each image_id is associated with multiple captions.
  id_to_captions = {}
  for annotation in caption_data["annotations"]:
    id_to_captions.setdefault(annotation["image_id"], []).append(
        annotation["caption"])

  # Every image must have captions and every caption must refer to an image.
  assert len(id_to_filename) == len(id_to_captions)
  assert set(x[0] for x in id_to_filename) == set(id_to_captions.keys())
  print("Loaded caption metadata for %d images from %s" %
        (len(id_to_filename), captions_file))

  # Process the captions and combine the data into a list of ImageMetadata.
  print("Proccessing captions.")
  image_metadata = []
  num_captions = 0
  for image_id, base_filename in id_to_filename:
    filename = os.path.join(image_dir, base_filename)
    captions = [_process_caption(c) for c in id_to_captions[image_id]]
    image_metadata.append(ImageMetadata(image_id, filename, captions))
    num_captions += len(captions)
  print("Finished processing %d captions for %d images in %s" %
        (num_captions, len(id_to_filename), captions_file))

  return image_metadata
def main(unused_argv):
  """Builds the vocabulary and writes the train/val/test TFRecord shards.

  Raises:
    ValueError: If FLAGS.num_threads or a shard count is not positive, or if a
      shard count is at least FLAGS.num_threads but not a multiple of it.
  """
  # Explicit checks instead of asserts, which are stripped under `python -O`.
  if FLAGS.num_threads < 1:
    raise ValueError("FLAGS.num_threads must be positive, got %d" %
                     FLAGS.num_threads)

  def _check_num_shards(flag_name, num_shards):
    """Raises ValueError unless num_shards is usable with FLAGS.num_threads."""
    if num_shards < 1:
      raise ValueError("FLAGS.%s must be positive, got %d" %
                       (flag_name, num_shards))
    # Fewer shards than threads is fine: _process_dataset then uses one thread
    # per shard. Otherwise each thread must own the same number of shards.
    if num_shards >= FLAGS.num_threads and num_shards % FLAGS.num_threads:
      raise ValueError(
          "Please make the FLAGS.num_threads commensurate with FLAGS.%s" %
          flag_name)

  _check_num_shards("train_shards", FLAGS.train_shards)
  _check_num_shards("val_shards", FLAGS.val_shards)
  _check_num_shards("test_shards", FLAGS.test_shards)

  if not tf.gfile.IsDirectory(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # Load image metadata from caption files.
  mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file,
                                                    FLAGS.train_image_dir)
  mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file,
                                                  FLAGS.val_image_dir)

  # Redistribute the MSCOCO data as follows:
  #   train_dataset = 100% of mscoco_train_dataset + 85% of mscoco_val_dataset.
  #   val_dataset = 5% of mscoco_val_dataset (for validation during training).
  #   test_dataset = 10% of mscoco_val_dataset (for final evaluation).
  train_cutoff = int(0.85 * len(mscoco_val_dataset))
  val_cutoff = int(0.90 * len(mscoco_val_dataset))
  train_dataset = mscoco_train_dataset + mscoco_val_dataset[0:train_cutoff]
  val_dataset = mscoco_val_dataset[train_cutoff:val_cutoff]
  test_dataset = mscoco_val_dataset[val_cutoff:]

  # Create vocabulary from the training captions.
  train_captions = [c for image in train_dataset for c in image.captions]
  vocab = _create_vocab(train_captions)

  _process_dataset("train", train_dataset, vocab, FLAGS.train_shards)
  _process_dataset("val", val_dataset, vocab, FLAGS.val_shards)
  _process_dataset("test", test_dataset, vocab, FLAGS.test_shards)


if __name__ == "__main__":
  tf.app.run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555. | |
"Show and Tell: A Neural Image Caption Generator" | |
Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan | |
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
from im2txt.ops import image_embedding | |
from im2txt.ops import image_processing | |
from im2txt.ops import inputs as input_ops | |
class ShowAndTellModel(object):
  """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.

  "Show and Tell: A Neural Image Caption Generator"
  Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan

  Call build() to create the graph: the input pipeline, image and word
  embeddings, the LSTM decoder (loss ops in "train"/"eval" mode, a softmax in
  "inference" mode), the Inception restore function and the global step.
  """

  def __init__(self, config, mode, train_inception=False):
    """Basic setup.

    Args:
      config: Object containing configuration parameters.
      mode: "train", "eval" or "inference".
      train_inception: Whether the inception submodel variables are trainable.
    """
    assert mode in ["train", "eval", "inference"]
    self.config = config
    self.mode = mode
    self.train_inception = train_inception

    # Reader for the input data.
    self.reader = tf.TFRecordReader()

    # To match the "Show and Tell" paper we initialize all variables with a
    # random uniform initializer.
    self.initializer = tf.random_uniform_initializer(
        minval=-self.config.initializer_scale,
        maxval=self.config.initializer_scale)

    # A float32 Tensor with shape [batch_size, height, width, channels].
    self.images = None

    # An int32 Tensor with shape [batch_size, padded_length].
    self.input_seqs = None

    # An int32 Tensor with shape [batch_size, padded_length].
    self.target_seqs = None

    # An int32 0/1 Tensor with shape [batch_size, padded_length].
    self.input_mask = None

    # A float32 Tensor with shape [batch_size, embedding_size].
    self.image_embeddings = None

    # A float32 Tensor with shape [batch_size, padded_length, embedding_size].
    self.seq_embeddings = None

    # A float32 scalar Tensor; the total loss for the trainer to optimize.
    self.total_loss = None

    # A float32 Tensor with shape [batch_size * padded_length].
    self.target_cross_entropy_losses = None

    # A float32 Tensor with shape [batch_size * padded_length].
    self.target_cross_entropy_loss_weights = None

    # Collection of variables from the inception submodel.
    self.inception_variables = []

    # Function to restore the inception submodel from checkpoint.
    self.init_fn = None

    # Global step Tensor.
    self.global_step = None

  def is_training(self):
    """Returns true if the model is built for training mode."""
    return self.mode == "train"

  def process_image(self, encoded_image, thread_id=0):
    """Decodes and processes an image string.

    Args:
      encoded_image: A scalar string Tensor; the encoded image.
      thread_id: Preprocessing thread id used to select the ordering of color
        distortions.

    Returns:
      A float32 Tensor of shape [height, width, 3]; the processed image.
    """
    return image_processing.process_image(encoded_image,
                                          is_training=self.is_training(),
                                          height=self.config.image_height,
                                          width=self.config.image_width,
                                          thread_id=thread_id,
                                          image_format=self.config.image_format)

  def build_inputs(self):
    """Input prefetching, preprocessing and batching.

    Outputs:
      self.images
      self.input_seqs
      self.target_seqs (training and eval only)
      self.input_mask (training and eval only)
    """
    if self.mode == "inference":
      # In inference mode, images and inputs are fed via placeholders.
      image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
      input_feed = tf.placeholder(dtype=tf.int64,
                                  shape=[None],  # batch_size
                                  name="input_feed")

      # Process image and insert batch dimensions.
      images = tf.expand_dims(self.process_image(image_feed), 0)
      input_seqs = tf.expand_dims(input_feed, 1)

      # No target sequences or input mask in inference mode.
      target_seqs = None
      input_mask = None
    else:
      # Prefetch serialized SequenceExample protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          self.config.input_file_pattern,
          is_training=self.is_training(),
          batch_size=self.config.batch_size,
          values_per_shard=self.config.values_per_input_shard,
          input_queue_capacity_factor=self.config.input_queue_capacity_factor,
          num_reader_threads=self.config.num_input_reader_threads)

      # Image processing and random distortion. Split across multiple threads
      # with each thread applying a slightly different distortion.
      assert self.config.num_preprocess_threads % 2 == 0
      images_and_captions = []
      for thread_id in range(self.config.num_preprocess_threads):
        serialized_sequence_example = input_queue.dequeue()
        encoded_image, caption = input_ops.parse_sequence_example(
            serialized_sequence_example,
            image_feature=self.config.image_feature_name,
            caption_feature=self.config.caption_feature_name)
        image = self.process_image(encoded_image, thread_id=thread_id)
        images_and_captions.append([image, caption])

      # Batch inputs.
      queue_capacity = (2 * self.config.num_preprocess_threads *
                        self.config.batch_size)
      images, input_seqs, target_seqs, input_mask = (
          input_ops.batch_with_dynamic_pad(images_and_captions,
                                           batch_size=self.config.batch_size,
                                           queue_capacity=queue_capacity))

    self.images = images
    self.input_seqs = input_seqs
    self.target_seqs = target_seqs
    self.input_mask = input_mask

  def build_image_embeddings(self):
    """Builds the image model subgraph and generates image embeddings.

    Inputs:
      self.images

    Outputs:
      self.image_embeddings
    """
    inception_output = image_embedding.inception_v3(
        self.images,
        trainable=self.train_inception,
        is_training=self.is_training())
    self.inception_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")

    # Map inception output into embedding space.
    with tf.variable_scope("image_embedding") as scope:
      image_embeddings = tf.contrib.layers.fully_connected(
          inputs=inception_output,
          num_outputs=self.config.embedding_size,
          activation_fn=None,
          weights_initializer=self.initializer,
          biases_initializer=None,
          scope=scope)

    # Save the embedding size in the graph.
    tf.constant(self.config.embedding_size, name="embedding_size")

    self.image_embeddings = image_embeddings

  def build_seq_embeddings(self):
    """Builds the input sequence embeddings.

    Inputs:
      self.input_seqs

    Outputs:
      self.seq_embeddings
    """
    # The embedding matrix is pinned to the CPU.
    with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
      embedding_map = tf.get_variable(
          name="map",
          shape=[self.config.vocab_size, self.config.embedding_size],
          initializer=self.initializer)
      seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)

    self.seq_embeddings = seq_embeddings

  def build_model(self):
    """Builds the model.

    In inference mode this creates ops named "initial_state", "state_feed"
    and "state" under the "lstm" scope, plus a "softmax" op, for step-by-step
    decoding by feeding and fetching concatenated LSTM states.

    Inputs:
      self.image_embeddings
      self.seq_embeddings
      self.target_seqs (training and eval only)
      self.input_mask (training and eval only)

    Outputs:
      self.total_loss (training and eval only)
      self.target_cross_entropy_losses (training and eval only)
      self.target_cross_entropy_loss_weights (training and eval only)
    """
    # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
    # modified LSTM in the "Show and Tell" paper has no biases and outputs
    # new_c * sigmoid(o).
    # NOTE(review): some TF 1.x releases expose this only as
    # tf.contrib.rnn.BasicLSTMCell — confirm against the pinned TF version.
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
        num_units=self.config.num_lstm_units, state_is_tuple=True)
    if self.mode == "train":
      lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
          lstm_cell,
          input_keep_prob=self.config.lstm_dropout_keep_prob,
          output_keep_prob=self.config.lstm_dropout_keep_prob)

    with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
      # Feed the image embeddings to set the initial LSTM state.
      zero_state = lstm_cell.zero_state(
          batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
      _, initial_state = lstm_cell(self.image_embeddings, zero_state)

      # Allow the LSTM variables to be reused.
      lstm_scope.reuse_variables()

      if self.mode == "inference":
        # In inference mode, use concatenated states for convenient feeding and
        # fetching. tf.concat(values, axis) replaces the removed tf.concat_v2.
        tf.concat(values=initial_state, axis=1, name="initial_state")

        # Placeholder for feeding a batch of concatenated states.
        state_feed = tf.placeholder(dtype=tf.float32,
                                    shape=[None, sum(lstm_cell.state_size)],
                                    name="state_feed")
        state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)

        # Run a single LSTM step.
        lstm_outputs, state_tuple = lstm_cell(
            inputs=tf.squeeze(self.seq_embeddings, axis=[1]),
            state=state_tuple)

        # Concatenate the resulting state.
        tf.concat(values=state_tuple, axis=1, name="state")
      else:
        # Run the batch of sequence embeddings through the LSTM.
        sequence_length = tf.reduce_sum(self.input_mask, 1)
        lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
                                            inputs=self.seq_embeddings,
                                            sequence_length=sequence_length,
                                            initial_state=initial_state,
                                            dtype=tf.float32,
                                            scope=lstm_scope)

    # Stack batches vertically.
    lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])

    with tf.variable_scope("logits") as logits_scope:
      logits = tf.contrib.layers.fully_connected(
          inputs=lstm_outputs,
          num_outputs=self.config.vocab_size,
          activation_fn=None,
          weights_initializer=self.initializer,
          scope=logits_scope)

    if self.mode == "inference":
      tf.nn.softmax(logits, name="softmax")
    else:
      targets = tf.reshape(self.target_seqs, [-1])
      weights = tf.to_float(tf.reshape(self.input_mask, [-1]))

      # Compute losses; padded positions get weight 0 from the input mask.
      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                              logits=logits)
      batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)),
                          tf.reduce_sum(weights),
                          name="batch_loss")
      # tf.losses supersedes the deprecated tf.contrib.losses helpers.
      tf.losses.add_loss(batch_loss)
      total_loss = tf.losses.get_total_loss()

      # Add summaries.
      tf.summary.scalar("losses/batch_loss", batch_loss)
      tf.summary.scalar("losses/total_loss", total_loss)
      for var in tf.trainable_variables():
        tf.summary.histogram("parameters/" + var.op.name, var)

      self.total_loss = total_loss
      self.target_cross_entropy_losses = losses  # Used in evaluation.
      self.target_cross_entropy_loss_weights = weights  # Used in evaluation.

  def setup_inception_initializer(self):
    """Sets up the function to restore inception variables from checkpoint."""
    if self.mode != "inference":
      # Restore inception variables only.
      saver = tf.train.Saver(self.inception_variables)

      def restore_fn(sess):
        tf.logging.info("Restoring Inception variables from checkpoint file %s",
                        self.config.inception_checkpoint_file)
        saver.restore(sess, self.config.inception_checkpoint_file)

      self.init_fn = restore_fn

  def setup_global_step(self):
    """Sets up the global step Tensor."""
    global_step = tf.Variable(
        initial_value=0,
        name="global_step",
        trainable=False,
        collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

    self.global_step = global_step

  def build(self):
    """Creates all ops for training and evaluation."""
    self.build_inputs()
    self.build_image_embeddings()
    self.build_seq_embeddings()
    self.build_model()
    self.setup_inception_initializer()
    self.setup_global_step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment