Skip to content

Instantly share code, notes, and snippets.

@inyl
Last active July 1, 2020 18:08
Show Gist options
  • Save inyl/5679509f1211a80d0abea149338d49c8 to your computer and use it in GitHub Desktop.
Save inyl/5679509f1211a80d0abea149338d49c8 to your computer and use it in GitHub Desktop.
TensorFlow models im2txt (Python 3 version)
# By default, targets declared here are visible only to the ":internal"
# package group defined below.
package(default_visibility = [":internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

# Every package under //im2txt belongs to the "internal" group.
package_group(
    name = "internal",
    packages = ["//im2txt/..."],
)
# Converts MSCOCO images and caption JSON into sharded TFRecord files.
# Pinned to Python 3 via `python_version`, the supported replacement for the
# deprecated (and in newer Bazel, removed) `default_python_version` attribute.
py_binary(
    name = "build_mscoco_data",
    srcs = [
        "data/build_mscoco_data.py",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
)

# Shell entry point that carries :build_mscoco_data as a runtime data
# dependency.
sh_binary(
    name = "download_and_preprocess_mscoco",
    srcs = ["data/download_and_preprocess_mscoco.sh"],
    data = [
        ":build_mscoco_data",
    ],
)
# Wraps configuration.py; depended on by the model test and the
# train / evaluate / run_inference binaries.
py_library(
    name = "configuration",
    srcs = ["configuration.py"],
    srcs_version = "PY2AND3",
)

# The Show and Tell model graph, built on top of the //im2txt/ops libraries.
py_library(
    name = "show_and_tell_model",
    srcs = ["show_and_tell_model.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//im2txt/ops:image_embedding",
        "//im2txt/ops:image_processing",
        "//im2txt/ops:inputs",
    ],
)

py_test(
    name = "show_and_tell_model_test",
    size = "large",
    srcs = ["show_and_tell_model_test.py"],
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

# Adapts the model to the interface in
# //im2txt/inference_utils:inference_wrapper_base.
py_library(
    name = "inference_wrapper",
    srcs = ["inference_wrapper.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":show_and_tell_model",
        "//im2txt/inference_utils:inference_wrapper_base",
    ],
)
# Entry-point binaries. All are pinned to Python 3 using `python_version`,
# which replaces the deprecated `default_python_version` attribute.

py_binary(
    name = "train",
    srcs = ["train.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

py_binary(
    name = "evaluate",
    srcs = ["evaluate.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":configuration",
        ":show_and_tell_model",
    ],
)

py_binary(
    name = "run_inference",
    srcs = ["run_inference.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":configuration",
        ":inference_wrapper",
        "//im2txt/inference_utils:caption_generator",
        "//im2txt/inference_utils:vocabulary",
    ],
)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts MSCOCO data to TFRecord file format with SequenceExample protos.
The MSCOCO images are expected to reside in JPEG files located in the following
directory structure:
train_image_dir/COCO_train2014_000000000151.jpg
train_image_dir/COCO_train2014_000000000260.jpg
...
and
val_image_dir/COCO_val2014_000000000042.jpg
val_image_dir/COCO_val2014_000000000073.jpg
...
The MSCOCO annotations JSON files are expected to reside in train_captions_file
and val_captions_file respectively.
This script converts the combined MSCOCO data into sharded data files consisting
of 256, 4 and 8 TFRecord files, respectively:
output_dir/train-00000-of-00256
output_dir/train-00001-of-00256
...
output_dir/train-00255-of-00256
and
output_dir/val-00000-of-00004
...
output_dir/val-00003-of-00004
and
output_dir/test-00000-of-00008
...
output_dir/test-00007-of-00008
Each TFRecord file contains ~2300 records. Each record within the TFRecord file
is a serialized SequenceExample proto consisting of precisely one image-caption
pair. Note that each image has multiple captions (usually 5) and therefore each
image is replicated multiple times in the TFRecord files.
The SequenceExample proto contains the following fields:
context:
image/image_id: integer MSCOCO image identifier
image/data: string containing JPEG encoded image in RGB colorspace
feature_lists:
image/caption: list of strings containing the (tokenized) caption words
image/caption_ids: list of integer ids corresponding to the caption words
The captions are tokenized using the NLTK (http://www.nltk.org/) word tokenizer.
The vocabulary of word identifiers is constructed from the sorted list (by
descending frequency) of word tokens in the training set. Only tokens appearing
at least 4 times are considered; all other words get the "unknown" word id.
NOTE: This script will consume around 100GB of disk space because each image
in the MSCOCO dataset is replicated ~5 times (once per caption) in the output.
This is done for two reasons:
1. In order to better shuffle the training data.
2. It makes it easier to perform asynchronous preprocessing of each image in
TensorFlow.
Running this script using 16 threads may take around 1 hour on a HP Z420.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import Counter
from collections import namedtuple
from datetime import datetime
import json
import os.path
import random
import sys
import threading
import nltk.tokenize
import numpy as np
import tensorflow as tf
# --- Input locations ---------------------------------------------------------
tf.flags.DEFINE_string(
    "train_image_dir", "/tmp/train2014/", "Training image directory.")
tf.flags.DEFINE_string(
    "val_image_dir", "/tmp/val2014", "Validation image directory.")
tf.flags.DEFINE_string(
    "train_captions_file", "/tmp/captions_train2014.json",
    "Training captions JSON file.")
tf.flags.DEFINE_string(
    "val_captions_file", "/tmp/captions_val2014.json",
    "Validation captions JSON file.")

# --- Output location and sharding -------------------------------------------
tf.flags.DEFINE_string(
    "output_dir", "/tmp/", "Output data directory.")
tf.flags.DEFINE_integer(
    "train_shards", 256, "Number of shards in training TFRecord files.")
tf.flags.DEFINE_integer(
    "val_shards", 4, "Number of shards in validation TFRecord files.")
tf.flags.DEFINE_integer(
    "test_shards", 8, "Number of shards in testing TFRecord files.")

# --- Caption vocabulary -----------------------------------------------------
tf.flags.DEFINE_string(
    "start_word", "<S>",
    "Special word added to the beginning of each sentence.")
tf.flags.DEFINE_string(
    "end_word", "</S>",
    "Special word added to the end of each sentence.")
tf.flags.DEFINE_string(
    "unknown_word", "<UNK>",
    "Special word meaning 'unknown'.")
tf.flags.DEFINE_integer(
    "min_word_count", 4,
    "The minimum number of occurrences of each word in the "
    "training set for inclusion in the vocabulary.")
tf.flags.DEFINE_string(
    "word_counts_output_file", "/tmp/word_counts.txt",
    "Output vocabulary file of word counts.")

# --- Parallelism ------------------------------------------------------------
tf.flags.DEFINE_integer(
    "num_threads", 8, "Number of threads to preprocess the images.")

FLAGS = tf.flags.FLAGS
# One record per image: its MSCOCO id, the path of its JPEG file, and its list
# of tokenized captions (later split so each record carries one caption).
ImageMetadata = namedtuple("ImageMetadata", "image_id filename captions")
class Vocabulary(object):
    """Maps word strings to integer ids, falling back to an 'unknown' id."""

    def __init__(self, vocab, unk_id):
        """Stores the word-to-id lookup table.

        Args:
          vocab: A dictionary mapping each in-vocabulary word to its id.
          unk_id: Id returned for any word that is not a key of `vocab`.
        """
        self._vocab = vocab
        self._unk_id = unk_id

    def word_to_id(self, word):
        """Returns the integer id of `word`, or the unknown id if absent."""
        return self._vocab.get(word, self._unk_id)
class ImageDecoder(object):
    """Helper class for decoding images in TensorFlow."""

    def __init__(self):
        # A single Session is created once and reused for every decode call.
        self._sess = tf.Session()

        # Graph: a string placeholder fed with raw JPEG bytes, decoded with
        # channels=3.
        self._encoded_jpeg = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._encoded_jpeg, channels=3)

    def decode_jpeg(self, encoded_jpeg):
        """Decodes JPEG data and sanity-checks the resulting image.

        Args:
          encoded_jpeg: The raw contents of a JPEG file.

        Returns:
          The decoded image returned by Session.run; it is guaranteed to have
          rank 3 with a last dimension of size 3.

        Raises:
          AssertionError: if the decoded image is not rank 3 with 3 channels.
            The check is raised explicitly rather than via `assert`, so it
            still runs under `python -O`. The AssertionError type is kept
            because callers in this file catch it to skip bad files.
          Errors from Session.run (e.g. tf.errors.InvalidArgumentError, which
            the caller in this file also catches) propagate unchanged.
        """
        image = self._sess.run(self._decode_jpeg,
                               feed_dict={self._encoded_jpeg: encoded_jpeg})
        # `or` short-circuits, so shape[2] is only read when the rank is 3.
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise AssertionError(
                "Expected a decoded image of shape [height, width, 3], got %s"
                % (image.shape,))
        return image
def _int64_feature(value):
    """Wraps one integer in an int64 tf.train.Feature for a SequenceExample."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wraps one bytes value in a bytes tf.train.Feature for a SequenceExample."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _int64_feature_list(values):
    """Builds a FeatureList with one int64 Feature per element of `values`."""
    features = list(map(_int64_feature, values))
    return tf.train.FeatureList(feature=features)
def _bytes_feature_list(values):
    """Builds a FeatureList with one bytes Feature per string in `values`.

    Each string is encoded with `str.encode` (UTF-8 by default) before being
    wrapped.
    """
    encoded_words = [str.encode(word) for word in values]
    features = [_bytes_feature(raw) for raw in encoded_words]
    return tf.train.FeatureList(feature=features)
def _to_sequence_example(image, decoder, vocab):
    """Builds a SequenceExample proto for a single image-caption pair.

    Args:
      image: An ImageMetadata object whose `captions` holds exactly one
        tokenized caption.
      decoder: An ImageDecoder object, used only to verify that the file holds
        decodable JPEG data.
      vocab: A Vocabulary object.

    Returns:
      A SequenceExample proto, or None if the image could not be decoded.
    """
    # JPEG data is binary. Open in "rb" so GFile returns the raw bytes instead
    # of treating the file as text.
    with tf.gfile.FastGFile(image.filename, "rb") as f:
        encoded_image = f.read()

    try:
        decoder.decode_jpeg(encoded_image)
    except (tf.errors.InvalidArgumentError, AssertionError):
        print("Skipping file with invalid JPEG data: %s" % image.filename)
        return None

    context = tf.train.Features(feature={
        "image/image_id": _int64_feature(image.image_id),
        "image/data": _bytes_feature(encoded_image),
    })

    # Internal invariant: _process_dataset splits images to one caption each.
    assert len(image.captions) == 1
    caption = image.captions[0]
    caption_ids = [vocab.word_to_id(word) for word in caption]
    # (A leftover per-record debug print of the caption was removed here; it
    # flooded stdout from every worker thread.)
    feature_lists = tf.train.FeatureLists(feature_list={
        "image/caption": _bytes_feature_list(caption),
        "image/caption_ids": _int64_feature_list(caption_ids),
    })
    return tf.train.SequenceExample(
        context=context, feature_lists=feature_lists)
def _process_image_files(thread_index, ranges, name, images, decoder, vocab,
                         num_shards):
    """Processes and saves a subset of images as TFRecord files in one thread.

    Args:
      thread_index: Integer thread identifier within [0, len(ranges)).
      ranges: A list of pairs of integers specifying the ranges of the dataset
        to process in parallel.
      name: Unique identifier specifying the dataset.
      images: List of ImageMetadata.
      decoder: An ImageDecoder object.
      vocab: A Vocabulary object.
      num_shards: Integer number of shards for the output files.
    """
    # Each thread produces N shards where N = num_shards / num_threads. For
    # instance, if num_shards = 128, and num_threads = 2, then the first thread
    # would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = num_shards // num_threads

    start, end = ranges[thread_index]
    shard_ranges = np.linspace(start, end, num_shards_per_batch + 1).astype(int)
    num_images_in_thread = end - start

    counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
        shard = thread_index * num_shards_per_batch + s
        output_filename = "%s-%.5d-of-%.5d" % (name, shard, num_shards)
        output_file = os.path.join(FLAGS.output_dir, output_filename)

        shard_counter = 0
        writer = tf.python_io.TFRecordWriter(output_file)
        try:
            for i in range(shard_ranges[s], shard_ranges[s + 1]):
                sequence_example = _to_sequence_example(images[i], decoder, vocab)
                if sequence_example is None:
                    # Undecodable image; _to_sequence_example already reported it.
                    continue
                writer.write(sequence_example.SerializeToString())
                shard_counter += 1
                counter += 1
                # Report only right after a write, so a skipped image cannot
                # re-trigger the report for an unchanged counter value.
                if not counter % 1000:
                    print("%s [thread %d]: Processed %d of %d items in thread batch." %
                          (datetime.now(), thread_index, counter, num_images_in_thread))
                    sys.stdout.flush()
        finally:
            # Close the shard even if a write fails, so the file is not leaked.
            writer.close()

        print("%s [thread %d]: Wrote %d image-caption pairs to %s" %
              (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()

    print("%s [thread %d]: Wrote %d image-caption pairs to %d shards." %
          (datetime.now(), thread_index, counter, num_shards_per_batch))
    sys.stdout.flush()
def _process_dataset(name, images, vocab, num_shards):
    """Processes a complete data set and saves it as sharded TFRecord files.

    Args:
      name: Unique identifier specifying the dataset.
      images: List of ImageMetadata.
      vocab: A Vocabulary object.
      num_shards: Integer number of shards for the output files.
    """
    # Break up each image into a separate entity for each caption.
    images = [ImageMetadata(image.image_id, image.filename, [caption])
              for image in images for caption in image.captions]

    # Shuffle the ordering of images. Make the randomization repeatable.
    random.seed(12345)
    random.shuffle(images)

    # Break the images into num_threads batches. Batch i is defined as
    # images[ranges[i][0]:ranges[i][1]].
    num_threads = min(num_shards, FLAGS.num_threads)
    # Use the builtin `int`: the `np.int` alias was removed in NumPy 1.24.
    spacing = np.linspace(0, len(images), num_threads + 1).astype(int)
    ranges = [[spacing[i], spacing[i + 1]] for i in range(len(spacing) - 1)]

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a utility for decoding JPEG images to run sanity checks. The one
    # decoder (and its Session) is shared by all worker threads.
    decoder = ImageDecoder()

    # Launch a thread for each batch.
    print("Launching %d threads for spacings: %s" % (num_threads, ranges))
    threads = []
    for thread_index in range(len(ranges)):
        args = (thread_index, ranges, name, images, decoder, vocab, num_shards)
        t = threading.Thread(target=_process_image_files, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print("%s: Finished processing all %d image-caption pairs in data set '%s'." %
          (datetime.now(), len(images), name))
def _create_vocab(captions):
    """Builds the word -> id Vocabulary from tokenized training captions.

    Words occurring at least FLAGS.min_word_count times are kept, ordered by
    descending count. A word's id is its position in that order, which is also
    its 0-based line number in FLAGS.word_counts_output_file. That file
    receives one "word count" pair per line. Every other word maps to the
    unknown id, which is one past the last kept id.

    Args:
      captions: A list of lists of strings.

    Returns:
      A Vocabulary object.
    """
    print("Creating vocabulary.")
    word_counter = Counter()
    for tokens in captions:
        word_counter.update(tokens)
    print("Total words:", len(word_counter))

    # Keep only sufficiently frequent words, most frequent first (stable sort).
    kept = sorted(
        ((word, count) for word, count in word_counter.items()
         if count >= FLAGS.min_word_count),
        key=lambda pair: pair[1],
        reverse=True)
    print("Words in vocabulary:", len(kept))

    with tf.gfile.FastGFile(FLAGS.word_counts_output_file, "w") as f:
        f.write("\n".join("%s %d" % (word, count) for word, count in kept))
    print("Wrote vocabulary file:", FLAGS.word_counts_output_file)

    word_ids = {word: index for index, (word, _) in enumerate(kept)}
    return Vocabulary(word_ids, len(kept))
def _process_caption(caption):
    """Lower-cases and tokenizes a caption, framing it with start/end words.

    Args:
      caption: A string caption.

    Returns:
      A list of strings: FLAGS.start_word, the NLTK word tokens of the
      lower-cased caption, then FLAGS.end_word.
    """
    words = nltk.tokenize.word_tokenize(caption.lower())
    return [FLAGS.start_word, *words, FLAGS.end_word]
def _load_and_process_metadata(captions_file, image_dir):
    """Loads image metadata from a JSON file and processes the captions.

    Args:
      captions_file: JSON file containing caption annotations.
      image_dir: Directory containing the image files.

    Returns:
      A list of ImageMetadata.

    Raises:
      ValueError: if the image ids listed under "images" do not match, one to
        one, the image ids referenced under "annotations".
    """
    # Read raw bytes and decode explicitly. In text mode GFile may already
    # return `str`, and `str(<str>, encoding=...)` raises TypeError.
    with tf.gfile.FastGFile(captions_file, "rb") as f:
        caption_data = json.loads(f.read().decode("utf-8"))

    # Extract the filenames.
    id_to_filename = [(x["id"], x["file_name"]) for x in caption_data["images"]]

    # Extract the captions. Each image_id is associated with multiple captions.
    id_to_captions = {}
    for annotation in caption_data["annotations"]:
        id_to_captions.setdefault(annotation["image_id"], []).append(
            annotation["caption"])

    # Validate explicitly (not with `assert`) so the check survives `python -O`.
    image_ids = {image_id for image_id, _ in id_to_filename}
    if (len(id_to_filename) != len(id_to_captions) or
            image_ids != set(id_to_captions)):
        raise ValueError(
            "Image ids in 'images' and 'annotations' do not match in %s" %
            captions_file)
    print("Loaded caption metadata for %d images from %s" %
          (len(id_to_filename), captions_file))

    # Process the captions and combine the data into a list of ImageMetadata.
    print("Proccessing captions.")
    image_metadata = []
    num_captions = 0
    for image_id, base_filename in id_to_filename:
        filename = os.path.join(image_dir, base_filename)
        captions = [_process_caption(c) for c in id_to_captions[image_id]]
        image_metadata.append(ImageMetadata(image_id, filename, captions))
        num_captions += len(captions)
    print("Finished processing %d captions for %d images in %s" %
          (num_captions, len(id_to_filename), captions_file))

    return image_metadata
def main(unused_argv):
    """Builds the vocabulary and writes the train/val/test TFRecord shards.

    Raises:
      ValueError: if a *_shards flag is not compatible with FLAGS.num_threads.
    """
    def _is_valid_num_shards(num_shards):
        """Returns True if num_shards is compatible with FLAGS.num_threads."""
        return num_shards < FLAGS.num_threads or not num_shards % FLAGS.num_threads

    # Validate explicitly rather than with `assert`, so bad flag combinations
    # are still rejected when Python runs with -O.
    for flag_name in ("train_shards", "val_shards", "test_shards"):
        if not _is_valid_num_shards(getattr(FLAGS, flag_name)):
            raise ValueError(
                "Please make the FLAGS.num_threads commensurate with FLAGS.%s"
                % flag_name)

    if not tf.gfile.IsDirectory(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    # Load image metadata from caption files.
    mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file,
                                                      FLAGS.train_image_dir)
    mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file,
                                                    FLAGS.val_image_dir)

    # Redistribute the MSCOCO data as follows:
    #   train_dataset = 100% of mscoco_train_dataset + 85% of mscoco_val_dataset.
    #   val_dataset = 5% of mscoco_val_dataset (for validation during training).
    #   test_dataset = 10% of mscoco_val_dataset (for final evaluation).
    train_cutoff = int(0.85 * len(mscoco_val_dataset))
    val_cutoff = int(0.90 * len(mscoco_val_dataset))
    train_dataset = mscoco_train_dataset + mscoco_val_dataset[0:train_cutoff]
    val_dataset = mscoco_val_dataset[train_cutoff:val_cutoff]
    test_dataset = mscoco_val_dataset[val_cutoff:]

    # Create vocabulary from the training captions.
    train_captions = [c for image in train_dataset for c in image.captions]
    vocab = _create_vocab(train_captions)

    _process_dataset("train", train_dataset, vocab, FLAGS.train_shards)
    _process_dataset("val", val_dataset, vocab, FLAGS.val_shards)
    _process_dataset("test", test_dataset, vocab, FLAGS.test_shards)
if __name__ == "__main__":
    # Hand control to tf.app.run, naming this module's main() explicitly
    # rather than relying on its lookup of __main__.main.
    tf.app.run(main=main)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
"Show and Tell: A Neural Image Caption Generator"
Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from im2txt.ops import image_embedding
from im2txt.ops import image_processing
from im2txt.ops import inputs as input_ops
class ShowAndTellModel(object):
    """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.

    "Show and Tell: A Neural Image Caption Generator"
    Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
    """

    def __init__(self, config, mode, train_inception=False):
        """Basic setup.

        Args:
          config: Object containing configuration parameters.
          mode: "train", "eval" or "inference".
          train_inception: Whether the inception submodel variables are trainable.
        """
        assert mode in ["train", "eval", "inference"]
        self.config = config
        self.mode = mode
        self.train_inception = train_inception

        # Reader for the input data.
        self.reader = tf.TFRecordReader()

        # To match the "Show and Tell" paper we initialize all variables with a
        # random uniform initializer.
        self.initializer = tf.random_uniform_initializer(
            minval=-self.config.initializer_scale,
            maxval=self.config.initializer_scale)

        # A float32 Tensor with shape [batch_size, height, width, channels].
        self.images = None

        # An int32 Tensor with shape [batch_size, padded_length].
        self.input_seqs = None

        # An int32 Tensor with shape [batch_size, padded_length].
        self.target_seqs = None

        # An int32 0/1 Tensor with shape [batch_size, padded_length].
        self.input_mask = None

        # A float32 Tensor with shape [batch_size, embedding_size].
        self.image_embeddings = None

        # A float32 Tensor with shape [batch_size, padded_length, embedding_size].
        self.seq_embeddings = None

        # A float32 scalar Tensor; the total loss for the trainer to optimize.
        self.total_loss = None

        # A float32 Tensor with shape [batch_size * padded_length].
        self.target_cross_entropy_losses = None

        # A float32 Tensor with shape [batch_size * padded_length].
        self.target_cross_entropy_loss_weights = None

        # Collection of variables from the inception submodel.
        self.inception_variables = []

        # Function to restore the inception submodel from checkpoint.
        self.init_fn = None

        # Global step Tensor.
        self.global_step = None

    def is_training(self):
        """Returns true if the model is built for training mode."""
        return self.mode == "train"

    def process_image(self, encoded_image, thread_id=0):
        """Decodes and processes an image string.

        Args:
          encoded_image: A scalar string Tensor; the encoded image.
          thread_id: Preprocessing thread id used to select the ordering of
            color distortions.

        Returns:
          A float32 Tensor of shape [height, width, 3]; the processed image.
        """
        return image_processing.process_image(encoded_image,
                                              is_training=self.is_training(),
                                              height=self.config.image_height,
                                              width=self.config.image_width,
                                              thread_id=thread_id,
                                              image_format=self.config.image_format)

    def build_inputs(self):
        """Input prefetching, preprocessing and batching.

        Outputs:
          self.images
          self.input_seqs
          self.target_seqs (training and eval only)
          self.input_mask (training and eval only)
        """
        if self.mode == "inference":
            # In inference mode, images and inputs are fed via placeholders.
            image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
            input_feed = tf.placeholder(dtype=tf.int64,
                                        shape=[None],  # batch_size
                                        name="input_feed")

            # Process image and insert batch dimensions.
            images = tf.expand_dims(self.process_image(image_feed), 0)
            input_seqs = tf.expand_dims(input_feed, 1)

            # No target sequences or input mask in inference mode.
            target_seqs = None
            input_mask = None
        else:
            # Prefetch serialized SequenceExample protos.
            input_queue = input_ops.prefetch_input_data(
                self.reader,
                self.config.input_file_pattern,
                is_training=self.is_training(),
                batch_size=self.config.batch_size,
                values_per_shard=self.config.values_per_input_shard,
                input_queue_capacity_factor=self.config.input_queue_capacity_factor,
                num_reader_threads=self.config.num_input_reader_threads)

            # Image processing and random distortion. Split across multiple
            # threads with each thread applying a slightly different distortion.
            assert self.config.num_preprocess_threads % 2 == 0
            images_and_captions = []
            for thread_id in range(self.config.num_preprocess_threads):
                serialized_sequence_example = input_queue.dequeue()
                encoded_image, caption = input_ops.parse_sequence_example(
                    serialized_sequence_example,
                    image_feature=self.config.image_feature_name,
                    caption_feature=self.config.caption_feature_name)
                image = self.process_image(encoded_image, thread_id=thread_id)
                images_and_captions.append([image, caption])

            # Batch inputs.
            queue_capacity = (2 * self.config.num_preprocess_threads *
                              self.config.batch_size)
            images, input_seqs, target_seqs, input_mask = (
                input_ops.batch_with_dynamic_pad(images_and_captions,
                                                 batch_size=self.config.batch_size,
                                                 queue_capacity=queue_capacity))

        self.images = images
        self.input_seqs = input_seqs
        self.target_seqs = target_seqs
        self.input_mask = input_mask

    def build_image_embeddings(self):
        """Builds the image model subgraph and generates image embeddings.

        Inputs:
          self.images

        Outputs:
          self.image_embeddings
        """
        inception_output = image_embedding.inception_v3(
            self.images,
            trainable=self.train_inception,
            is_training=self.is_training())
        self.inception_variables = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV3")

        # Map inception output into embedding space.
        with tf.variable_scope("image_embedding") as scope:
            image_embeddings = tf.contrib.layers.fully_connected(
                inputs=inception_output,
                num_outputs=self.config.embedding_size,
                activation_fn=None,
                weights_initializer=self.initializer,
                biases_initializer=None,
                scope=scope)

        # Save the embedding size in the graph.
        tf.constant(self.config.embedding_size, name="embedding_size")

        self.image_embeddings = image_embeddings

    def build_seq_embeddings(self):
        """Builds the input sequence embeddings.

        Inputs:
          self.input_seqs

        Outputs:
          self.seq_embeddings
        """
        with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
            embedding_map = tf.get_variable(
                name="map",
                shape=[self.config.vocab_size, self.config.embedding_size],
                initializer=self.initializer)
            seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)

        self.seq_embeddings = seq_embeddings

    def build_model(self):
        """Builds the model.

        Inputs:
          self.image_embeddings
          self.seq_embeddings
          self.target_seqs (training and eval only)
          self.input_mask (training and eval only)

        Outputs:
          self.total_loss (training and eval only)
          self.target_cross_entropy_losses (training and eval only)
          self.target_cross_entropy_loss_weights (training and eval only)
        """
        # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
        # modified LSTM in the "Show and Tell" paper has no biases and outputs
        # new_c * sigmoid(o).
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
            num_units=self.config.num_lstm_units, state_is_tuple=True)
        if self.mode == "train":
            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
                lstm_cell,
                input_keep_prob=self.config.lstm_dropout_keep_prob,
                output_keep_prob=self.config.lstm_dropout_keep_prob)

        with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
            # Feed the image embeddings to set the initial LSTM state.
            zero_state = lstm_cell.zero_state(
                batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
            _, initial_state = lstm_cell(self.image_embeddings, zero_state)

            # Allow the LSTM variables to be reused.
            lstm_scope.reuse_variables()

            if self.mode == "inference":
                # In inference mode, use concatenated states for convenient
                # feeding and fetching. `tf.concat(values, axis)` replaces the
                # removed `tf.concat_v2`.
                tf.concat(axis=1, values=initial_state, name="initial_state")

                # Placeholder for feeding a batch of concatenated states.
                state_feed = tf.placeholder(dtype=tf.float32,
                                            shape=[None, sum(lstm_cell.state_size)],
                                            name="state_feed")
                state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1)

                # Run a single LSTM step. (`axis` replaces the deprecated
                # `squeeze_dims` argument.)
                lstm_outputs, state_tuple = lstm_cell(
                    inputs=tf.squeeze(self.seq_embeddings, axis=[1]),
                    state=state_tuple)

                # Concatenate the resulting state.
                tf.concat(axis=1, values=state_tuple, name="state")
            else:
                # Run the batch of sequence embeddings through the LSTM.
                sequence_length = tf.reduce_sum(self.input_mask, 1)
                lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
                                                    inputs=self.seq_embeddings,
                                                    sequence_length=sequence_length,
                                                    initial_state=initial_state,
                                                    dtype=tf.float32,
                                                    scope=lstm_scope)

        # Stack batches vertically.
        lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])

        with tf.variable_scope("logits") as logits_scope:
            logits = tf.contrib.layers.fully_connected(
                inputs=lstm_outputs,
                num_outputs=self.config.vocab_size,
                activation_fn=None,
                weights_initializer=self.initializer,
                scope=logits_scope)

        if self.mode == "inference":
            tf.nn.softmax(logits, name="softmax")
        else:
            targets = tf.reshape(self.target_seqs, [-1])
            weights = tf.to_float(tf.reshape(self.input_mask, [-1]))

            # Compute losses; the mask weights exclude padded positions.
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                                    logits=logits)
            batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, weights)),
                                tf.reduce_sum(weights),
                                name="batch_loss")
            # tf.losses replaces the deprecated tf.contrib.losses module.
            tf.losses.add_loss(batch_loss)
            total_loss = tf.losses.get_total_loss()

            # Add summaries.
            tf.summary.scalar("losses/batch_loss", batch_loss)
            tf.summary.scalar("losses/total_loss", total_loss)
            for var in tf.trainable_variables():
                tf.summary.histogram("parameters/" + var.op.name, var)

            self.total_loss = total_loss
            self.target_cross_entropy_losses = losses  # Used in evaluation.
            self.target_cross_entropy_loss_weights = weights  # Used in evaluation.

    def setup_inception_initializer(self):
        """Sets up the function to restore inception variables from checkpoint."""
        if self.mode != "inference":
            # Restore inception variables only.
            saver = tf.train.Saver(self.inception_variables)

            def restore_fn(sess):
                tf.logging.info("Restoring Inception variables from checkpoint file %s",
                                self.config.inception_checkpoint_file)
                saver.restore(sess, self.config.inception_checkpoint_file)

            self.init_fn = restore_fn

    def setup_global_step(self):
        """Sets up the global step Tensor."""
        global_step = tf.Variable(
            initial_value=0,
            name="global_step",
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.GLOBAL_VARIABLES])

        self.global_step = global_step

    def build(self):
        """Creates all ops for training and evaluation."""
        self.build_inputs()
        self.build_image_embeddings()
        self.build_seq_embeddings()
        self.build_model()
        self.setup_inception_initializer()
        self.setup_global_step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment