Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Last active September 7, 2020 16:14
Show Gist options
  • Save MLWhiz/ff33936a251995ee5892999a08dee271 to your computer and use it in GitHub Desktop.
Save MLWhiz/ff33936a251995ee5892999a08dee271 to your computer and use it in GitHub Desktop.
class BiLSTM(nn.Module):
    """Bidirectional-LSTM text model with mean+max pooling over time.

    Token ids are embedded and run through a single-layer bidirectional LSTM
    (``batch_first=True``). The per-timestep outputs are summarised by
    concatenating their mean and their max over the sequence dimension. A
    Linear -> ReLU -> Dropout head then maps the pooled features to a single
    output value per example. No final activation is applied, so the output
    is an unnormalised score.

    Every constructor argument defaults to the value the original
    implementation hard-coded, so ``BiLSTM()`` builds the same model.

    Args:
        max_features: Number of rows in the embedding table (vocabulary size).
        embed_size: Embedding dimension fed into the LSTM.
        hidden_size: LSTM hidden size per direction.
        drop_prob: Dropout probability applied after the hidden FC layer.
        fc_size: Width of the intermediate fully connected layer.
    """

    def __init__(self, max_features=10000, embed_size=300, hidden_size=64,
                 drop_prob=0.1, fc_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(max_features, embed_size)
        self.lstm = nn.LSTM(embed_size, self.hidden_size,
                            bidirectional=True, batch_first=True)
        # Mean- and max-pooled BiLSTM outputs are each 2 * hidden_size wide
        # (forward + backward directions), so their concatenation is
        # 4 * hidden_size wide.
        self.linear = nn.Linear(self.hidden_size * 4, fc_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)
        self.out = nn.Linear(fc_size, 1)

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids shaped ``(batch, seq_len)``. The
                batch axis comes first because the LSTM uses
                ``batch_first=True``. Ids must be below ``max_features``.

        Returns:
            Tensor of shape ``(batch, 1)`` with one unnormalised score per
            example.
        """
        # The original code ran squeeze(unsqueeze(...)) here. That is a no-op
        # in general, but it stripped *every* size-1 axis, so batch size 1 or
        # seq_len 1 broke the pooling below. The embedding output is already
        # (batch, seq_len, embed_size), which is what the LSTM expects.
        h_embedding = self.embedding(x)
        h_lstm, _ = self.lstm(h_embedding)  # (batch, seq_len, 2 * hidden_size)
        avg_pool = torch.mean(h_lstm, 1)    # (batch, 2 * hidden_size)
        max_pool, _ = torch.max(h_lstm, 1)  # (batch, 2 * hidden_size)
        conc = torch.cat((avg_pool, max_pool), 1)  # (batch, 4 * hidden_size)
        conc = self.relu(self.linear(conc))
        conc = self.dropout(conc)
        return self.out(conc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment