@ntakouris
Created June 26, 2021 07:43
from tensorflow import keras

# Note: SelfAttention is assumed to be the companion multi-head self-attention
# layer defined alongside this block; it is not a built-in keras layer.

class SelfAttentionBlock(keras.Model):
    """Pre-norm transformer encoder block: self-attention followed by a
    pointwise feed-forward network, each with a residual connection."""

    def __init__(self, head_size, num_heads=1, ff_dim=None, dropout=0,
                 name='SelfAttentionBlock', **kwargs):
        super().__init__(name=name, **kwargs)
        if ff_dim is None:
            ff_dim = head_size
        self.attention = SelfAttention(head_size, num_heads, dropout=dropout)
        self.attention_dropout = keras.layers.Dropout(dropout)
        self.attention_norm = keras.layers.LayerNormalization(epsilon=1e-6)
        self.ff_conv1 = keras.layers.Conv1D(
            filters=ff_dim, kernel_size=1, activation='relu')
        # self.ff_conv2 is created in build(), once the input feature
        # dimension is known.
        self.ff_dropout = keras.layers.Dropout(dropout)
        self.ff_norm = keras.layers.LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        # Project the feed-forward output back to the input feature dimension
        # so the residual additions in call() are shape-compatible.
        self.ff_conv2 = keras.layers.Conv1D(
            filters=input_shape[-1], kernel_size=1)

    def call(self, inputs, training=None, **kwargs):
        # Self-attention sub-block with residual connection.
        x = self.attention_norm(inputs, **kwargs)
        x = self.attention(x, **kwargs)
        x = self.attention_dropout(x, training=training, **kwargs)
        res = x + inputs

        # Pointwise feed-forward sub-block with residual connection.
        x = self.ff_norm(res, **kwargs)
        x = self.ff_conv1(x, **kwargs)
        x = self.ff_dropout(x, training=training, **kwargs)
        x = self.ff_conv2(x, **kwargs)
        x = self.ff_dropout(x, training=training, **kwargs)
        return x + res
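
A minimal usage sketch, assuming the SelfAttention layer from the rest of this gist is in scope: because ff_conv2 projects back to the input feature dimension, the block preserves the (batch, time_steps, features) shape and can be stacked. The shapes and hyperparameters below are illustrative only.

import numpy as np

# Hypothetical input: 32 sequences, 64 time steps, 8 features per step.
x = np.random.rand(32, 64, 8).astype('float32')

block = SelfAttentionBlock(head_size=16, num_heads=4, ff_dim=32, dropout=0.1)
y = block(x, training=False)
print(y.shape)  # (32, 64, 8) -- same shape as the input, so blocks can be chained

# Stacking the block inside a functional model for a regression head.
inputs = keras.Input(shape=(64, 8))
h = SelfAttentionBlock(head_size=16, num_heads=4, ff_dim=32, dropout=0.1)(inputs)
h = SelfAttentionBlock(head_size=16, num_heads=4, ff_dim=32, dropout=0.1)(h)
h = keras.layers.GlobalAveragePooling1D()(h)
outputs = keras.layers.Dense(1)(h)
model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')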