""" | |
Blog post: | |
Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your health: | |
https://medium.com/@_willfalcon/taming-lstms-variable-sized-mini-batches-and-why-pytorch-is-good-for-your-health-61d35642972e | |
""" | |
def forward(self, X, X_lengths):
    # reset the LSTM hidden state. Must be done before you run a new batch. Otherwise the LSTM
    # will treat the new batch as a continuation of the previous sequence
    self.hidden = self.init_hidden()

    # X is the padded batch of token indices: (batch_size, seq_len)
    batch_size, seq_len = X.size()
    # ---------------------
    # 1. embed the input
    # Dim transformation: (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
    X = self.word_embedding(X)
    # ---------------------
    # 2. Run through the RNN
    # TRICK 2 ********************************
    # Dim transformation: (batch_size, seq_len, embedding_dim) -> (batch_size, seq_len, nb_lstm_units)

    # pack_padded_sequence so that padded items in the sequence won't be shown to the LSTM
    # (lengths must be sorted in decreasing order unless you pass enforce_sorted=False, PyTorch >= 1.1)
    X = torch.nn.utils.rnn.pack_padded_sequence(X, X_lengths, batch_first=True)

    # now run through the LSTM
    X, self.hidden = self.lstm(X, self.hidden)

    # undo the packing operation
    X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True)
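
    # Illustrative aside (an assumption for clarity, not from the original post): with
    # batch_first=True and lengths [3, 1] for a padded batch of shape (2, 3, embedding_dim),
    # packing keeps only the 4 real timesteps, the LSTM never steps over the 2 pad positions,
    # and pad_packed_sequence restores a (batch_size, max_seq_len, nb_lstm_units) tensor
    # with zeros at the padded positions.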
    # ---------------------
    # 3. Project to tag space
    # Dim transformation: (batch_size, seq_len, nb_lstm_units) -> (batch_size * seq_len, nb_lstm_units)

    # this one is a bit tricky as well. First we need to reshape the data so it goes into the linear layer
    X = X.contiguous()
    X = X.view(-1, X.shape[2])

    # run through the actual linear layer
    X = self.hidden_to_tag(X)
    # ---------------------
    # 4. Create log-softmax activations because we're doing classification
    # The linear layer above already mapped nb_lstm_units -> nb_tags, so X is (batch_size * seq_len, nb_tags);
    # log_softmax over dim=1 normalizes across the tag dimension
    X = F.log_softmax(X, dim=1)

    # I like to reshape for mental sanity so we're back to (batch_size, seq_len, nb_tags)
    X = X.view(batch_size, seq_len, self.nb_tags)

    Y_hat = X
    return Y_hat
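

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original gist). It wraps
# the forward() above in a hypothetical TaggerLSTM module that supplies the
# attributes forward() expects (word_embedding, lstm, hidden_to_tag, init_hidden,
# nb_tags). All hyperparameters below are made up for the example, not the blog
# post's values.
# ---------------------------------------------------------------------------
class TaggerLSTM(torch.nn.Module):
    def __init__(self, vocab_size=100, embedding_dim=8, nb_lstm_units=16, nb_tags=5, batch_size=2):
        super().__init__()
        self.nb_lstm_units = nb_lstm_units
        self.nb_tags = nb_tags
        self.batch_size = batch_size
        # padding_idx=0 keeps the pad token's embedding fixed at zero
        self.word_embedding = torch.nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = torch.nn.LSTM(embedding_dim, nb_lstm_units, batch_first=True)
        self.hidden_to_tag = torch.nn.Linear(nb_lstm_units, nb_tags)

    def init_hidden(self):
        # one layer, one direction: (num_layers, batch_size, nb_lstm_units)
        h0 = torch.zeros(1, self.batch_size, self.nb_lstm_units)
        c0 = torch.zeros(1, self.batch_size, self.nb_lstm_units)
        return (h0, c0)


# attach the forward() defined above as the module's forward pass
TaggerLSTM.forward = forward


if __name__ == "__main__":
    model = TaggerLSTM()
    # two sequences padded to length 4 with the pad token 0; lengths sorted longest-first
    # so pack_padded_sequence's default enforce_sorted=True is satisfied
    X = torch.tensor([[5, 12, 7, 3],
                      [9, 4, 0, 0]])
    X_lengths = [4, 2]
    Y_hat = model(X, X_lengths)
    print(Y_hat.shape)  # torch.Size([2, 4, 5]) -> (batch_size, seq_len, nb_tags)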