@pythonlessons
Created August 16, 2023 12:45
transformer_attention
import tensorflow as tf


class FeedForward(tf.keras.layers.Layer):
    """
    Implements the position-wise feed-forward sublayer of a Transformer block.

    Attributes:
        seq (tf.keras.Sequential): The sequential block containing the two Dense layers and the Dropout layer.
        add (tf.keras.layers.Add): The Add layer used for the residual (skip) connection.
        layer_norm (tf.keras.layers.LayerNormalization): The LayerNormalization layer applied after the residual addition.

    Methods:
        call: Performs the forward pass of the layer.
    """
    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        """
        Constructor of the FeedForward layer.

        Args:
            d_model (int): The dimensionality of the model (input and output size).
            dff (int): The dimensionality of the inner feed-forward layer.
            dropout_rate (float): The dropout rate applied after the projection back to d_model.
        """
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),  # expand to dff with ReLU
            tf.keras.layers.Dense(d_model),                 # project back to d_model
            tf.keras.layers.Dropout(dropout_rate)
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Performs the feed-forward operation with a residual connection and layer normalization.

        Args:
            x (tf.Tensor): The input sequence of shape (batch_size, seq_length, d_model).

        Returns:
            tf.Tensor: The output sequence of shape (batch_size, seq_length, d_model).
        """
        x = self.add([x, self.seq(x)])  # residual connection: x + FFN(x)
        x = self.layer_norm(x)          # post-LN: normalize after the residual addition
        return x
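
A minimal usage sketch of the layer above. The hyperparameter values (d_model=512, dff=2048) and the input shape are illustrative assumptions, not part of the gist; any batch size and sequence length work because the layer preserves the input shape.

# Usage sketch (assumed hyperparameters, not from the gist)
import tensorflow as tf

d_model, dff = 512, 2048                   # common Transformer-base sizes (assumption)
ffn = FeedForward(d_model=d_model, dff=dff, dropout_rate=0.1)

x = tf.random.uniform((2, 10, d_model))    # (batch_size, seq_length, d_model)
y = ffn(x)                                 # Dense(dff) -> Dense(d_model) -> Dropout,
                                           # then residual add and LayerNormalization
print(y.shape)                             # (2, 10, 512): shape is preserved

Because the output shape matches the input shape, the residual Add is well defined, and the layer can be stacked directly after an attention sublayer inside an encoder or decoder block.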