# Quantum multi-head attention: the classical Q/K/V projection layers are
# replaced by variational quantum circuits wrapped as Keras layers.
# Assumes TensorFlow 2.x, PennyLane, and a MultiHeadAttentionBase class
# (the classical multi-head attention scaffold) defined elsewhere.
import tensorflow as tf
import pennylane as qml


class MultiHeadAttentionQuantum(MultiHeadAttentionBase):
    def __init__(self,
                 embed_dim, num_heads,
                 n_qubits, n_qlayers=1, q_device='default.qubit'):
        super(MultiHeadAttentionQuantum, self).__init__(embed_dim, num_heads)
        # TODO: add an intermediate layer to "dress" the quantum circuit
        assert n_qubits == embed_dim, \
            f"Number of qubits ({n_qubits}) does not match embedding dim ({embed_dim})"

        self.dev = qml.device(q_device, wires=n_qubits)
        weight_shapes = {"weights": (n_qlayers, n_qubits)}
        print(f"weight_shapes = (n_qlayers, n_qubits) = ({n_qlayers}, {n_qubits})")

        def _circuit(inputs, weights):
            # Encode the classical input vector as rotation angles, entangle the
            # qubits, and read out one Pauli-Z expectation value per qubit.
            qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
            qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))
            return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

        self.qlayer = qml.QNode(_circuit, self.dev, interface="tf")
        # One quantum "dense" layer each for queries, keys, values, and the output projection.
        self.wq = qml.qnn.KerasLayer(self.qlayer, weight_shapes, output_dim=n_qubits)
        self.wk = qml.qnn.KerasLayer(self.qlayer, weight_shapes, output_dim=n_qubits)
        self.wv = qml.qnn.KerasLayer(self.qlayer, weight_shapes, output_dim=n_qubits)
        self.dense = qml.qnn.KerasLayer(self.qlayer, weight_shapes, output_dim=n_qubits)

    def apply_dense_layers(self, v, k, q):
        # Use the static shape so seq_len is a plain Python int that can drive
        # the per-timestep loops below (tf.shape would return tensors here).
        batch_size, seq_len, _ = q.shape
        # The quantum Keras layers act on one token vector at a time, so apply
        # them per time step and re-stack the results.
        q = [self.wq(q[:, t, :]) for t in range(seq_len)]  # (seq_len, batch_size, embed_dim)
        k = [self.wk(k[:, t, :]) for t in range(seq_len)]  # (seq_len, batch_size, embed_dim)
        v = [self.wv(v[:, t, :]) for t in range(seq_len)]  # (seq_len, batch_size, embed_dim)
        q = tf.convert_to_tensor(q)
        k = tf.convert_to_tensor(k)
        v = tf.convert_to_tensor(v)
        q = tf.transpose(q, perm=[1, 0, 2])  # (batch_size, seq_len, embed_dim)
        k = tf.transpose(k, perm=[1, 0, 2])  # (batch_size, seq_len, embed_dim)
        v = tf.transpose(v, perm=[1, 0, 2])  # (batch_size, seq_len, embed_dim)
        return v, k, q
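

# Minimal usage sketch (illustrative only, not part of the original gist):
# it exercises apply_dense_layers directly and assumes the constructor
# signature shown above; the tensor sizes below are arbitrary.
if __name__ == "__main__":
    embed_dim, num_heads, n_qubits = 4, 2, 4
    mha = MultiHeadAttentionQuantum(embed_dim, num_heads,
                                    n_qubits=n_qubits, n_qlayers=1)

    x = tf.random.uniform((8, 16, embed_dim))   # (batch_size, seq_len, embed_dim)
    v, k, q = mha.apply_dense_layers(x, x, x)   # quantum Q/K/V projections
    print(q.shape)                              # expected: (8, 16, 4)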