transformer_attention
import tensorflow as tf


class FeedForward(tf.keras.layers.Layer):
    """
    A class that implements the position-wise feed-forward layer of a Transformer block.

    Attributes:
        seq (tf.keras.Sequential): Sequential container holding the two Dense
            layers and the Dropout layer that make up the feed-forward network.
        add (tf.keras.layers.Add): The Add layer used for the residual connection.
        layer_norm (tf.keras.layers.LayerNormalization): The LayerNormalization layer.

    Methods:
        call: Performs the forward pass of the layer.
    """
    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        """
        Constructor of the FeedForward layer.

        Args:
            d_model (int): The dimensionality of the model (input and output size).
            dff (int): The dimensionality of the inner feed-forward layer.
            dropout_rate (float): The dropout rate.
        """
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate)
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Performs the feed-forward operation with a residual connection and layer normalization.

        Args:
            x (tf.Tensor): The input sequence of shape (batch_size, seq_length, d_model).

        Returns:
            tf.Tensor: The output sequence of shape (batch_size, seq_length, d_model).
        """
        # Residual connection: add the feed-forward output back to the input.
        x = self.add([x, self.seq(x)])
        # Layer normalization applied after the residual addition (post-LN).
        x = self.layer_norm(x)
        return x
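A minimal usage sketch follows, assuming the class above is in scope. The dimensions (d_model=512, dff=2048) are illustrative values, not part of the original file; they match the common Transformer-base configuration. The layer preserves the input shape, since the inner Dense layer expands to dff and the second Dense layer projects back to d_model.

# Hypothetical dimensions, chosen for illustration only.
batch_size, seq_length, d_model, dff = 2, 10, 512, 2048

ffn = FeedForward(d_model=d_model, dff=dff, dropout_rate=0.1)
x = tf.random.uniform((batch_size, seq_length, d_model))
out = ffn(x)
print(out.shape)  # (2, 10, 512) -- same shape as the input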