Skip to content

Instantly share code, notes, and snippets.

@rdisipio
Last active December 17, 2020 22:19
Show Gist options
  • Save rdisipio/36a1d6637d61c04b678d102493e2bf89 to your computer and use it in GitHub Desktop.
Save rdisipio/36a1d6637d61c04b678d102493e2bf89 to your computer and use it in GitHub Desktop.
# Build the classical and quantum layers of a quantum LSTM (QLSTM) cell.
# NOTE(review): the time-step loop reads these as `self.clayer_in`,
# `self.VQC` and `self.clayer_out`, so presumably this snippet is the body of
# an nn.Module `__init__` that assigns them to `self` -- confirm.

# Width of the concatenated [hidden state, input] vector fed to clayer_in.
concat_size = inputs_dim + hidden_dim
# Classical projection down to one feature per qubit.
clayer_in = torch.nn.Linear(concat_size, n_qubits)
# One variational quantum circuit per LSTM gate, in the order
# forget, input, update, output. Each TorchLayer instance holds its own
# trainable weights built from `weight_shapes`.
# A torch.nn.ModuleList (instead of a plain list) registers every TorchLayer
# as a submodule once stored on an nn.Module. Its weights then appear in
# .parameters() and state_dict() and follow .to(device). With a plain list
# the quantum weights are silently left out of training.
# Indexing (VQC[0]..VQC[3]) and iteration work exactly as with a list.
VQC = torch.nn.ModuleList(
    [qml.qnn.TorchLayer(qlayer, weight_shapes) for _ in range(4)]
)
# Classical projection from the qubit outputs back to the hidden width.
# NOTE(review): concat_size above uses `hidden_dim`, but this layer (and
# hence h_t) uses `hidden_size`. Since h_t is concatenated back into the
# clayer_in input, the two presumably must be equal -- confirm.
clayer_out = torch.nn.Linear(n_qubits, hidden_size)
# Unroll the QLSTM cell over the sequence, one time step at a time.
# NOTE(review): `self`, `x`, `seq_length`, `h_t` and `c_t` must already be
# defined before this runs (presumably inside the enclosing forward()).
# The indexing x[:, t, :] requires a rank-3 `x`; a (batch, seq, features)
# layout is assumed -- confirm against the caller.
hidden_seq = []
for t in range(seq_length):
    # Features of the t-th sequence element, for every entry in the batch.
    x_t = x[:, t, :]
    # Concatenate the previous hidden state and the current input (dim 1).
    v_t = torch.cat((h_t, x_t), dim=1)
    # Project to one feature per qubit for the quantum circuits.
    y_t = self.clayer_in(v_t)
    # Each gate runs its own VQC, then the shared clayer_out maps the
    # result back to the hidden width before the gate activation.
    f_t = torch.sigmoid(self.clayer_out(self.VQC[0](y_t)))  # forget block
    i_t = torch.sigmoid(self.clayer_out(self.VQC[1](y_t)))  # input block
    g_t = torch.tanh(self.clayer_out(self.VQC[2](y_t)))  # update block
    o_t = torch.sigmoid(self.clayer_out(self.VQC[3](y_t)))  # output block
    # LSTM cell-state and hidden-state update.
    c_t = (f_t * c_t) + (i_t * g_t)
    h_t = o_t * torch.tanh(c_t)
    # Prepend a length-1 axis to each step's hidden state; presumably the
    # steps are concatenated along dim 0 after the loop.
    hidden_seq.append(h_t.unsqueeze(0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment