Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Created March 9, 2019 15:01
Show Gist options
  • Select an option

  • Save MLWhiz/3f3324457bdfb3f3ab97e52e515a1b39 to your computer and use it in GitHub Desktop.

Select an option

Save MLWhiz/3f3324457bdfb3f3ab97e52e515a1b39 to your computer and use it in GitHub Desktop.
class BiLSTM(nn.Module):
    """Bidirectional LSTM sentence classifier over frozen pretrained embeddings.

    Pipeline: embedding lookup -> BiLSTM -> (mean-pool ++ max-pool over time)
    -> Linear+ReLU -> Dropout -> Linear(1). Returns raw logits (no sigmoid).

    Backward compatible with the original: ``BiLSTM()`` with no arguments
    falls back to the module-level globals ``max_features``, ``embed_size``
    and ``embedding_matrix``, exactly as the original hard-coded them.

    Args:
        vocab_size: embedding rows; defaults to global ``max_features``.
        emb_dim: embedding width; defaults to global ``embed_size``.
        pretrained: array-like (vocab_size, emb_dim) of pretrained weights;
            defaults to global ``embedding_matrix``. Frozen (not trained).
        hidden_size: LSTM hidden units per direction (default 64, as before).
        dropout: dropout probability before the output layer (default 0.1).
    """

    def __init__(self, vocab_size=None, emb_dim=None, pretrained=None,
                 hidden_size=64, dropout=0.1):
        super(BiLSTM, self).__init__()
        # Fall back to the module-level globals the original relied on.
        if vocab_size is None:
            vocab_size = max_features
        if emb_dim is None:
            emb_dim = embed_size
        if pretrained is None:
            pretrained = embedding_matrix
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.embedding.weight = nn.Parameter(
            torch.tensor(pretrained, dtype=torch.float32))
        self.embedding.weight.requires_grad = False  # freeze pretrained vectors
        self.lstm = nn.LSTM(emb_dim, hidden_size,
                            bidirectional=True, batch_first=True)
        # Concatenated avg-pool + max-pool of bidirectional outputs:
        # 2*hidden per pooling branch -> 4*hidden total input features.
        self.linear = nn.Linear(hidden_size * 4, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(64, 1)

    def forward(self, x):
        """Compute logits for a batch of token-id sequences.

        Args:
            x: LongTensor of shape (batch, seq_len) with token indices.

        Returns:
            FloatTensor of shape (batch, 1) — raw classification logits.
        """
        h_embedding = self.embedding(x)  # (batch, seq_len, emb_dim)
        # BUG FIX: the original did squeeze(unsqueeze(h, 0)), which strips
        # EVERY size-1 dimension — collapsing the batch axis when batch==1
        # and silently misaligning the batch_first LSTM and pooling axes.
        # For batch > 1 it was a no-op, so removing it preserves behavior.
        h_lstm, _ = self.lstm(h_embedding)       # (batch, seq_len, 2*hidden)
        avg_pool = torch.mean(h_lstm, 1)         # (batch, 2*hidden)
        max_pool, _ = torch.max(h_lstm, 1)       # (batch, 2*hidden)
        conc = torch.cat((avg_pool, max_pool), 1)  # (batch, 4*hidden)
        conc = self.relu(self.linear(conc))
        conc = self.dropout(conc)
        out = self.out(conc)
        return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment