Created
April 24, 2018 10:33
-
-
Save PhanDuc/795cf8e743040e5524226050514e5f5f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CaptionNet(nn.Module):
    """A recurrent "head" network for image captioning.

    Projects a CNN image-feature vector into the initial LSTM state,
    unrolls an LSTM over the caption token embeddings, and produces
    next-token logits at every time step.
    """

    def __init__(self, n_tokens=n_tokens, emb_size=128, lstm_units=256, cnn_feature_size=2048):
        """
        :param n_tokens: vocabulary size (defaults to the module-level ``n_tokens``)
        :param emb_size: dimensionality of the token embeddings
        :param lstm_units: number of LSTM hidden units
        :param cnn_feature_size: size of the incoming CNN feature vector
        """
        super().__init__()

        # Linear projections from CNN features to the initial hidden and
        # cell states of the LSTM.
        self.cnn_to_h0 = nn.Linear(cnn_feature_size, lstm_units)
        self.cnn_to_c0 = nn.Linear(cnn_feature_size, lstm_units)

        # Token-id -> embedding lookup table.
        self.emb = nn.Embedding(n_tokens, emb_size)

        # BUG FIX: the third positional argument of nn.LSTM is num_layers;
        # the original passed cnn_feature_size (2048) there, which would
        # construct a 2048-layer LSTM. One layer is intended.
        self.lstm = nn.LSTM(emb_size, lstm_units, batch_first=True)

        # LSTM hidden state -> one logit per vocabulary token.
        self.logits = nn.Linear(lstm_units, n_tokens)

    def forward(self, image_vectors, captions_ix):
        """
        Apply the network in training mode.

        :param image_vectors: inception feature vectors, shape [batch, cnn_feature_size]
        :param captions_ix: captions as an integer matrix, shape [batch, word_i],
            padded with pad_ix
        :returns: logits for the next token at each tick, shape [batch, word_i, n_tokens]
        """
        # Initial LSTM state derived from the image features.
        initial_cell = self.cnn_to_c0(image_vectors)
        initial_hid = self.cnn_to_h0(image_vectors)

        # Embed the caption tokens: [batch, word_i] -> [batch, word_i, emb_size].
        captions_emb = self.emb(captions_ix)

        # BUG FIX: nn.LSTM expects the state tuple as (h_0, c_0), each with a
        # leading num_layers dimension of 1; the original passed (cell, hidden)
        # in the wrong order and without that extra dimension.
        lstm_out, _ = self.lstm(captions_emb, (initial_hid[None], initial_cell[None]))

        # Per-step next-token logits: [batch, word_i, n_tokens].
        return self.logits(lstm_out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment