@ariG23498
Last active March 16, 2022 07:01
Just a Transformer block implemented in TensorFlow Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class TransformerBlock(layers.Layer):
    """A generic Transformer block with MHSA and MLP layers.

    Args:
        config: The configuration of the architecture.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        # Pre-norm multi-head self-attention sub-block.
        self.layer_norm1 = layers.LayerNormalization(
            epsilon=self.config.eps
        )
        self.mhsa = layers.MultiHeadAttention(
            num_heads=self.config.num_heads,
            key_dim=self.config.projection_dim,
            dropout=self.config.dropout_rate,
        )
        self.dropout = layers.Dropout(self.config.dropout_rate)
        self.residual_connection = layers.Add()

        # Pre-norm MLP sub-block with a 4x hidden expansion and GELU.
        self.layer_norm2 = layers.LayerNormalization(
            epsilon=self.config.eps
        )
        self.mlp = keras.Sequential([
            layers.Dense(
                units=4 * self.config.projection_dim,
                activation=tf.nn.gelu,
            ),
            layers.Dropout(self.config.dropout_rate),
            layers.Dense(
                units=self.config.projection_dim,
            ),
            layers.Dropout(self.config.dropout_rate),
        ])

    def get_config(self):
        # Serialize the architecture configuration; this assumes `config`
        # is dict-like (e.g. a dict or an ml_collections.ConfigDict).
        config = super().get_config()
        config.update(self.config)
        return config

    def call(self, inputs):
        # Attention sub-block: LayerNorm -> MHSA -> Dropout -> residual add.
        x1 = self.layer_norm1(inputs)
        attention_outputs = self.mhsa(
            query=x1,
            key=x1,
            value=x1,
        )
        attention_outputs = self.dropout(attention_outputs)
        x2 = self.residual_connection([attention_outputs, inputs])

        # MLP sub-block: LayerNorm -> MLP -> residual add.
        x3 = self.layer_norm2(x2)
        x4 = self.mlp(x3)
        outputs = self.residual_connection([x4, x2])
        return outputs
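
A minimal usage sketch. The config object here is a hypothetical stand-in (a SimpleNamespace with the fields the block reads: eps, num_heads, projection_dim, dropout_rate); swap in whatever configuration object the rest of your code uses.

from types import SimpleNamespace

import tensorflow as tf

# Hypothetical configuration with the attributes the block expects.
config = SimpleNamespace(
    eps=1e-6,
    num_heads=4,
    projection_dim=64,
    dropout_rate=0.1,
)

block = TransformerBlock(config)

# A batch of 2 sequences, each with 16 tokens of `projection_dim` features.
dummy_tokens = tf.random.normal((2, 16, config.projection_dim))
outputs = block(dummy_tokens)
print(outputs.shape)  # (2, 16, 64)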