Skip to content

Instantly share code, notes, and snippets.

@rdisipio
Created January 10, 2021 15:35
Show Gist options
  • Save rdisipio/4ccd1c0cf26a5c32fcbdd1f16c44a1f8 to your computer and use it in GitHub Desktop.
Save rdisipio/4ccd1c0cf26a5c32fcbdd1f16c44a1f8 to your computer and use it in GitHub Desktop.
class MultiHeadAttentionClassical(MultiHeadAttentionBase):
    """Multi-head attention whose query/key/value projections are Keras Dense layers.

    Four ``tf.keras.layers.Dense(embed_dim)`` layers are created: ``wq``, ``wk``
    and ``wv`` project the query, key and value inputs, and ``dense`` is stored
    for use elsewhere (presumably the output projection in
    ``MultiHeadAttentionBase`` — NOTE(review): confirm against the base class).
    """

    def __init__(self, embed_dim, num_heads):
        """Initialise the base class, then build the four projection layers.

        Args:
            embed_dim: number of output units of every Dense projection.
            num_heads: forwarded unchanged to ``MultiHeadAttentionBase``.
        """
        super().__init__(embed_dim, num_heads)
        make_dense = tf.keras.layers.Dense
        # Construction order (wq, wk, wv, dense) is deliberate: Keras assigns
        # default layer names such as dense, dense_1, ... in creation order.
        self.wq = make_dense(embed_dim)
        self.wk = make_dense(embed_dim)
        self.wv = make_dense(embed_dim)
        self.dense = make_dense(embed_dim)

    def apply_dense_layers(self, v, k, q):
        """Project value, key and query through their Dense layers.

        Each Dense layer acts on the last axis of its input and maps it to
        ``embed_dim`` units. Inputs are assumed to be
        (batch_size, seq_len, features) tensors — TODO confirm with callers.

        Returns:
            Tuple ``(v, k, q)`` of projected tensors, in the same order as
            the arguments.
        """
        # Call order q -> k -> v matches the original. Keras builds each
        # layer's weights on its first call, so the order is preserved.
        projected_q = self.wq(q)
        projected_k = self.wk(k)
        projected_v = self.wv(v)
        return projected_v, projected_k, projected_q
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment