@timotta
Last active July 29, 2021 12:32
LSTM binary classification using PyTorch and skorch, with a pretrained gensim word2vec model
# Prerequisites: a trained gensim word2vec model (w2v_model).
# X must be a matrix with one example per row and word indexes
# (into w2v_model.wv.vectors) in sequence as columns.
import numpy as np
import torch
import torch.nn as nn
from skorch import NeuralNetBinaryClassifier
from skorch.callbacks import EarlyStopping
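# --- Illustrative setup (not part of the original gist) ---
# A minimal sketch of how w2v_model, X and y might be prepared.
# Names like `sentences`, `labels`, `MAX_LEN` and `to_indexes` are
# assumptions for illustration; this uses the gensim 4.x Word2Vec API.
from gensim.models import Word2Vec

sentences = [["good", "movie"], ["bad", "plot"],
             ["great", "acting"], ["poor", "script"]] * 50  # toy corpus
labels = [1, 0, 1, 0] * 50

w2v_model = Word2Vec(sentences, vector_size=100, min_count=1)

MAX_LEN = 20
# The module below reserves the last embedding row as padding
pad_idx = len(w2v_model.wv.vectors) - 1

def to_indexes(tokens):
    # Map known tokens to word2vec indexes, then truncate/pad to MAX_LEN
    idxs = [w2v_model.wv.key_to_index[t] for t in tokens
            if t in w2v_model.wv.key_to_index]
    idxs = idxs[:MAX_LEN]
    return idxs + [pad_idx] * (MAX_LEN - len(idxs))

X = np.array([to_indexes(s) for s in sentences], dtype=np.int64)
y = np.array(labels, dtype=np.float32)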
class LSTMClassification(nn.Module):
    def __init__(self, embed_vector, hidden_dim=100, dropout=0.5):
        super(LSTMClassification, self).__init__()
        self.n_layers = 1
        self.output_size = 1
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_vector.shape[1]
        # Embedding sized from the pretrained matrix; the last row is
        # reserved as the padding index
        self.embedding = nn.Embedding(len(embed_vector), self.embed_dim,
                                      padding_idx=len(embed_vector) - 1)
        # Copy the pretrained word2vec weights into the embedding layer
        # (without this step the "pretrained" vectors would never be used)
        self.embedding.weight.data.copy_(torch.from_numpy(embed_vector))
        self.lstm = nn.LSTM(self.embed_dim, hidden_dim, num_layers=self.n_layers,
                            batch_first=True, bidirectional=True)
        self.linear = nn.Linear(hidden_dim * 2, self.output_size)
        self.drop = nn.Dropout(p=dropout)
    def forward(self, x):
        batch_size = x.size(0)
        embeds = self.embedding(x.long())
        hidden = self.init_hidden(batch_size)
        lstm_out, _ = self.lstm(embeds, hidden)
        # Take the output at the last time step; apply dropout before the
        # classifier head rather than to the final logit
        linear_input = self.drop(lstm_out[:, -1, :])
        return self.linear(linear_input)
    def init_hidden(self, batch_size):
        # Zero initial (h_0, c_0); new_zeros inherits dtype and device from
        # the module's parameters, so this works on CPU and GPU alike
        weight = next(self.parameters())
        shape = (self.n_layers * 2, batch_size, self.hidden_dim)  # *2: bidirectional
        return (weight.new_zeros(shape), weight.new_zeros(shape))
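# --- Illustrative sanity check (not part of the original gist) ---
# Run the module standalone on a random embedding matrix to verify output
# shapes; all names here are throwaway assumptions.
_vecs = np.random.rand(1000, 100).astype(np.float32)
_model = LSTMClassification(_vecs)
_logits = _model(torch.randint(0, 1000, (4, 20)))
print(_logits.shape)  # torch.Size([4, 1]): one logit per example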
net = NeuralNetBinaryClassifier(
    LSTMClassification,
    max_epochs=100,
    batch_size=1024,
    lr=0.001,
    module__embed_vector=w2v_model.wv.vectors,
    optimizer=torch.optim.Adam,
    callbacks=[EarlyStopping(monitor='valid_acc', lower_is_better=False)],
)
# Fix seeds for reproducibility before fitting
torch.manual_seed(1982)
torch.cuda.manual_seed(1982)
np.random.seed(1982)
print("Starting...")
net.fit(X, y)  # X: int index matrix; y: 1-d float32 array of 0/1 labels
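# --- Illustrative usage after fitting (not part of the original gist) ---
# skorch exposes the usual sklearn-style API on the fitted net
probs = net.predict_proba(X)  # shape (n, 2): columns for class 0 and class 1
preds = net.predict(X)        # hard 0/1 labels (0.5 threshold on the sigmoid)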