@rdisipio
Created December 17, 2020 22:41
QLSTM POS Tagger
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Assumes QLSTM is defined alongside this snippet (e.g. in a qlstm.py from the
# same project); the exact import path is a guess.
from qlstm import QLSTM

class LSTMTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, n_qubits=0):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # The LSTM takes word embeddings as inputs and outputs hidden states
        # with dimensionality hidden_dim. With n_qubits > 0 the gates are
        # realized by the variational quantum circuits inside QLSTM.
        if n_qubits > 0:
            print("Tagger will use Quantum LSTM")
            self.lstm = QLSTM(embedding_dim, hidden_dim, n_qubits=n_qubits)
        else:
            print("Tagger will use Classical LSTM")
            self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        # Shape the embeddings as (seq_len, batch=1, embedding_dim) for the LSTM.
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))
        tag_logits = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_logits, dim=1)
        return tag_scores
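For context, here is how input data could be prepared for this tagger. The toy sentences, tag set, and prepare_sequence helper below are illustrative assumptions in the style of the classic PyTorch sequence-model tutorial, not part of the original gist.

# Illustrative toy dataset (assumed, not from the gist).
training_data = [
    ("the dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("everybody read that book".split(), ["NN", "V", "DET", "NN"]),
]
word_to_ix = {}
for sent, _ in training_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}

def prepare_sequence(seq, to_ix):
    # Map tokens to integer indices and wrap them in a LongTensor.
    return torch.tensor([to_ix[t] for t in seq], dtype=torch.long)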
# Hyperparameter values here are illustrative; the original gist sets them elsewhere.
model = LSTMTagger(embedding_dim=8, hidden_dim=6,
                   vocab_size=len(word_to_ix), tagset_size=len(tag_to_ix),
                   n_qubits=4)

loss_function = nn.NLLLoss()  # the output is a log_softmax, so NLLLoss gives the cross-entropy
optimizer = optim.SGD(model.parameters(), lr=0.1)
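A minimal training loop sketch, assuming the toy data above; the epoch count is arbitrary.

for epoch in range(50):  # arbitrary number of epochs for this sketch
    for sentence, tags in training_data:
        model.zero_grad()  # clear gradients accumulated from the previous step
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)
        tag_scores = model(sentence_in)
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()

If QLSTM exposes its circuit weights as ordinary torch parameters (as in the PennyLane TorchLayer approach used in the accompanying QLSTM work), the same backward pass and optimizer step train the quantum gates end-to-end.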