How to construct an encoder RNN, with notes on pack_padded_sequence and pad_packed_sequence
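Before the gist itself, here is a minimal round-trip sketch (not part of the gist; the toy tensor and lengths are illustrative assumptions) showing what the two helpers do: pack_padded_sequence strips the padding steps from a padded batch so the RNN never sees them, and pad_packed_sequence restores the padded layout afterwards.

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Two sequences of lengths 4 and 2, zero-padded to length 4.
# Shape is (max_len, batch) because batch_first=False by default.
padded = torch.tensor([[1, 5],
                       [2, 6],
                       [3, 0],
                       [4, 0]])
lengths = torch.tensor([4, 2])  # sorted descending (required unless
                                # enforce_sorted=False is passed)

packed = pack_padded_sequence(padded, lengths)       # PackedSequence, no pad steps
unpacked, out_lengths = pad_packed_sequence(packed)  # back to (4, 2) plus lengths

assert torch.equal(unpacked, padded)
assert torch.equal(out_lengths, lengths)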
import torch
import torch.nn as nn


class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # Initialize GRU; the input_size and hidden_size params are both set to
        # 'hidden_size' because our input is a word embedding with number of
        # features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for the RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        # Return output and final hidden state
        return outputs, hidden
# Something useful:
# To summarize a sequence for a subsequent dense network, take the raw
# (un-summed) output[-1, :, :hidden_size] (forward RNN) and
# output[0, :, hidden_size:] (backward RNN), concatenate them, and feed the
# result to the dense layers. Note that with padded batches, output[-1] is
# only valid for the longest sequence; index the forward direction at
# input_lengths - 1 for each sequence instead.
# The returned hidden states are the ones after consuming the whole sequence
# (packing ensures padding steps are skipped), so they can be safely passed
# to the decoder.
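A hypothetical usage sketch of the class above; the vocabulary size, dimensions, and toy batch are illustrative assumptions, not part of the gist. It shows the expected input layout and the shapes that come back out.

# Illustrative sizes (assumptions, not from the gist)
vocab_size, hidden_size, max_len, batch_size = 10, 8, 5, 3
embedding = nn.Embedding(vocab_size, hidden_size)
encoder = EncoderRNN(hidden_size, embedding, n_layers=2, dropout=0.1)

# Padded batch of word indices, shape (max_len, batch);
# lengths must be sorted descending for pack_padded_sequence's default.
input_seq = torch.randint(1, vocab_size, (max_len, batch_size))
input_lengths = torch.tensor([5, 4, 2])
input_seq[4:, 1] = 0  # zero-pad sequence 1 beyond its length
input_seq[2:, 2] = 0  # zero-pad sequence 2 beyond its length

outputs, hidden = encoder(input_seq, input_lengths)
print(outputs.shape)  # (max_len, batch, hidden_size): the two directions are summed
print(hidden.shape)   # (n_layers * 2, batch, hidden_size): both directions kept

Because forward() sums the two directions, outputs has hidden_size features per step rather than 2 * hidden_size, while hidden still carries one state per layer per direction.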