@pythonlessons
Created August 22, 2023 14:47
build_transformer
import tensorflow as tf


class EncoderLayer(tf.keras.layers.Layer):
    """
    A single layer of the Encoder. Usually there are multiple layers stacked on top of each other.

    Methods:
        call: Performs the forward pass of the layer.

    Attributes:
        self_attention (GlobalSelfAttention): The global self-attention layer.
        ffn (FeedForward): The feed-forward layer.
    """
    def __init__(self, d_model: int, num_heads: int, dff: int, dropout_rate: float = 0.1):
        """
        Constructor of the EncoderLayer.

        Args:
            d_model (int): The dimensionality of the model.
            num_heads (int): The number of heads in the multi-head attention layer.
            dff (int): The dimensionality of the feed-forward layer.
            dropout_rate (float): The dropout rate.
        """
        super().__init__()
        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate,
        )
        self.ffn = FeedForward(d_model, dff)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        The call function that performs the forward pass of the layer.

        Args:
            x (tf.Tensor): The input sequence of shape (batch_size, seq_length, d_model).

        Returns:
            tf.Tensor: The output sequence of shape (batch_size, seq_length, d_model).
        """
        x = self.self_attention(x)
        x = self.ffn(x)
        return x
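
GlobalSelfAttention and FeedForward are defined elsewhere in this tutorial series and are not part of this gist. The sketch below is a minimal, assumed implementation of both, following the common residual-plus-layer-normalization pattern around tf.keras.layers.MultiHeadAttention and a two-layer Dense block; the class names, defaults, and the trailing usage example are illustrative assumptions, not the exact code from the series.

# Minimal sketch (assumed) of the two layers EncoderLayer depends on.
# The real implementations live in other parts of this tutorial series and may differ.

class BaseAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


class GlobalSelfAttention(BaseAttention):
    def call(self, x: tf.Tensor) -> tf.Tensor:
        # Self-attention: queries, keys, and values all come from x.
        attn_output = self.mha(query=x, value=x, key=x)
        x = self.add([x, attn_output])   # residual connection
        return self.layernorm(x)


class FeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation="relu"),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ])
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.add([x, self.seq(x)])   # residual connection
        return self.layernorm(x)


# Example usage with illustrative sizes: the output keeps the input shape.
sample_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048)
sample_input = tf.random.uniform((64, 50, 512))   # (batch_size, seq_length, d_model)
print(sample_layer(sample_input).shape)            # (64, 50, 512)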