2020-03-07 CNN-LSTM image captioning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)
        # remove the top fully connected layer
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.embed(features)  # [batch_size, embed_size]
        return features
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1,
                 p_dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        # The decoder embeds the inputs before feeding them to the LSTM.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            # padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(p=p_dropout)
        self.lstm = nn.LSTM(
            # The first time step receives the encoder's image feature vector;
            # subsequent steps receive the embedded target tokens.
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # initialize a hidden state (see init_hidden below); note that
        # forward() relies on the LSTM's default zero state instead.
        self.hidden = self.init_hidden()

    def init_hidden(self):
        '''At the start of training we need to initialize a hidden state;
        there is none yet because the hidden state is formed from previously
        seen data. This function returns a hidden state filled with random
        numbers of the specified size.'''
        # The axes dimensions are [num_layers, batch_size, hidden_size]
        return (torch.randn(1, 1, self.hidden_size),
                torch.randn(1, 1, self.hidden_size))
    def forward(self, features, captions):
        # shape of features: [batch_size, embed_size], e.g. [10, 256]
        # shape of captions: [batch_size, sequence_size], e.g. [10, 20]
        # Embed the caption tokens; prepending the image feature vector and
        # dropping the last token below shifts the target sequence right by one.
        # shape of embedded captions: [10, 20, 256]
        embedded = self.embedding(captions)
        embedded = self.dropout(embedded)
        lstm_input = torch.cat((features.unsqueeze(1), embedded), dim=1)
        lstm_input = lstm_input[:, :-1, :]  # drop the last token to keep the sequence length
        # Pass the embeddings (and an implicit zero hidden state) through the LSTM.
        # LSTM input shape: [batch_size, sequence_size, input_size], e.g. [10, 20, 256]
        # LSTM output shape: [batch_size, sequence_size, hidden_size], e.g. [10, 20, 512]
        lstm_output, _ = self.lstm(lstm_input)  # output, (hidden state, cell state)
        # shape of output: [batch_size, sequence_size, vocab_size], e.g. [10, 20, 8856]
        output = self.linear(lstm_output)
        output = F.log_softmax(output, dim=2)
        return output
    def sample(self, features, states=None, max_len=20):
        '''Accepts a pre-processed image tensor (features) and returns a
        predicted sentence (list of token ids of length max_len).'''
        # Inference: There are multiple approaches that can be used to generate
        # a sentence given an image with NIC. The first one is sampling, where
        # we sample the first word according to p1, then provide the
        # corresponding embedding as input and sample p2, continuing like this
        # until we sample the special end-of-sentence token or reach some
        # maximum length. https://arxiv.org/pdf/1411.4555.pdf
        # shape of features: [1, 256]
        # shape of word embedding: [1, 1, 256]
        lstm_input = features.unsqueeze(1)
        idxs = []
        for _ in range(max_len):
            lstm_output, states = self.lstm(lstm_input, states)
            output = self.linear(lstm_output)
            _, idx = torch.max(output[0][0], 0)  # greedy: pick the most likely token
            idxs.append(idx.item())
            # embedding input shape: [batch_size, sequence_size]
            # embedding output shape: [batch_size, sequence_size, embed_size]
            lstm_input = self.embedding(idx.unsqueeze(0).unsqueeze(0))
        return idxs
    def beam_search(self, features, states=None, max_len=20, k=20):
        '''Generate a sequence of max_len token ids from features using
        beam search with beam width k.'''
        # The second one is BeamSearch: iteratively consider the set of the
        # k best sentences up to time t as candidates to generate sentences
        # of size t + 1, and keep only the resulting best k of them. This
        # better approximates S = argmax_{S'} p(S'|I). The paper used the
        # BeamSearch approach with a beam of size 20; using a beam size of 1
        # (i.e., greedy search) degraded its results by 2 BLEU points on
        # average. https://arxiv.org/pdf/1411.4555.pdf
        topk = [[[], .0, None]]  # each entry: [sequence, score, key_states]
        states_prev, states_curr = {}, {}
        lstm_input = features.unsqueeze(1)
        for _ in range(max_len):
            candidates = []
            for i, (seq, score, key_states) in enumerate(topk):
                # get decoder output
                if seq:
                    lstm_input = self.embedding(seq[-1].unsqueeze(0).unsqueeze(0))
                    states = states_prev[key_states]
                lstm_output, states = self.lstm(lstm_input, states)
                # store hidden states
                states_curr[i] = states
                # get token log-probabilities
                output = self.linear(lstm_output)
                output = F.log_softmax(output, dim=2)
                output = output[0][0]
                # accumulate scores for every possible next token
                for idx, val in enumerate(output):
                    candidate = [seq + [torch.tensor(idx).to(output.device)],
                                 score + val.item(), i]
                    candidates.append(candidate)
            # update hidden states dictionary
            states_prev, states_curr = states_curr, {}
            # order all candidates by score and keep the k best
            topk = sorted(candidates, key=lambda x: x[1], reverse=True)[:k]
        return [idx.item() for idx in topk[0][0]]
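
# A minimal usage sketch: wires the encoder and decoder together on dummy data
# to illustrate the expected tensor shapes. The sizes below (embed_size=256,
# hidden_size=512, vocab_size=8856, 224x224 images) follow the examples in the
# comments above and are illustrative assumptions, not required values.
if __name__ == "__main__":
    embed_size, hidden_size, vocab_size = 256, 512, 8856
    encoder = EncoderCNN(embed_size)  # downloads pretrained ResNet-50 weights
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
    images = torch.randn(10, 3, 224, 224)               # dummy image batch
    captions = torch.randint(0, vocab_size, (10, 20))   # dummy caption token ids
    with torch.no_grad():
        features = encoder(images)                      # [10, 256]
        outputs = decoder(features, captions)           # [10, 20, 8856]
        sampled = decoder.sample(features[:1])          # list of 20 token ids
    print(features.shape, outputs.shape, len(sampled))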
Reference: "How to Implement a Beam Search Decoder for Natural Language Processing" by Jason Brownlee, January 5, 2018 (last updated August 7, 2019). https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/
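For comparison, here is a minimal standalone beam-search decoder in the spirit of that article: it scores candidate sequences over a matrix of pre-computed per-step token probabilities rather than a live LSTM decoder. The function name and toy data are illustrative assumptions, not taken from the article or the gist.

import math

def beam_search_decoder(probs, k):
    # probs: one probability distribution over the vocabulary per time step
    # each candidate is (token_id_sequence, cumulative_log_probability)
    sequences = [([], 0.0)]
    for step in probs:
        candidates = []
        for seq, score in sequences:
            for idx, p in enumerate(step):
                candidates.append((seq + [idx], score + math.log(p)))
        # keep only the k highest-scoring candidates
        sequences = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return sequences

# toy example: 4 time steps over a 5-token vocabulary
probs = [[0.1, 0.1, 0.2, 0.3, 0.3]] * 4
for seq, score in beam_search_decoder(probs, k=3):
    print(seq, score)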
CAUTION: