Skip to content

Instantly share code, notes, and snippets.

@sumedhpendurkar
Last active March 13, 2019 17:45
Show Gist options
  • Save sumedhpendurkar/ba9b7177b5ceddbbfd2890663221dc85 to your computer and use it in GitHub Desktop.
Save sumedhpendurkar/ba9b7177b5ceddbbfd2890663221dc85 to your computer and use it in GitHub Desktop.
Encoder class for sequence-to-sequence (seq2seq) modelling in PyTorch
class Encoder(nn.Module):
    """Single-step LSTM encoder for sequence-to-sequence models.

    Each call to ``forward`` consumes exactly one time step for a batch of
    one: the input is reshaped to ``(seq_len=1, batch=1, input_size)``.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in the LSTM hidden state.
        bidirectional: If true, use a bidirectional LSTM. The hidden state
            then holds two directions, and the output has
            ``2 * hidden_size`` features.
    """

    def __init__(self, input_size, hidden_size, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional)

    def forward(self, inputs, hidden):
        """Run the LSTM over a single time step.

        Args:
            inputs: Tensor holding exactly ``input_size`` elements (any
                shape); it is reshaped to ``(1, 1, input_size)``.
            hidden: ``(h_0, c_0)`` tuple, e.g. as produced by ``init_hidden``.

        Returns:
            ``(output, (h_n, c_n))`` as returned by ``nn.LSTM``; ``output``
            has shape ``(1, 1, num_directions * hidden_size)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted instead of raising a RuntimeError.
        step = inputs.reshape(1, 1, self.input_size)
        output, hidden = self.lstm(step, hidden)
        return output, hidden

    def init_hidden(self):
        """Return a zeroed ``(h_0, c_0)`` pair for a batch of one.

        Both tensors have shape ``(num_directions, 1, hidden_size)``. They are
        allocated on the same device and with the same dtype as the LSTM
        weights, so the encoder keeps working after ``.to(device)`` or a
        dtype cast such as ``.double()``.
        """
        # Mirror nn.LSTM's own rule: two directions iff `bidirectional` is truthy.
        num_directions = 2 if self.bidirectional else 1
        shape = (num_directions, 1, self.hidden_size)
        # new_zeros inherits device and dtype from the reference weight tensor.
        weight = next(self.lstm.parameters())
        return (weight.new_zeros(shape), weight.new_zeros(shape))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment