build_transformer
import tensorflow as tf


class DecoderLayer(tf.keras.layers.Layer):
    """
    A single layer of the Decoder. Usually there are multiple layers stacked on top of each other.

    Methods:
        call: Performs the forward pass of the layer.

    Attributes:
        causal_self_attention (CausalSelfAttention): The causal self-attention layer.
        cross_attention (CrossAttention): The cross-attention layer.
        ffn (FeedForward): The feed-forward layer.
    """
    def __init__(self, d_model: int, num_heads: int, dff: int, dropout_rate: float = 0.1):
        """
        Constructor of the DecoderLayer.

        Args:
            d_model (int): The dimensionality of the model.
            num_heads (int): The number of heads in the multi-head attention layer.
            dff (int): The dimensionality of the feed-forward layer.
            dropout_rate (float): The dropout rate.
        """
        super(DecoderLayer, self).__init__()

        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x: tf.Tensor, context: tf.Tensor) -> tf.Tensor:
        """
        The call function that performs the forward pass of the layer.

        Args:
            x (tf.Tensor): The input sequence of shape (batch_size, seq_length, d_model). x is usually the output of the previous decoder layer.
            context (tf.Tensor): The context sequence of shape (batch_size, seq_length, d_model). Context is usually the output of the encoder.

        Returns:
            tf.Tensor: The output sequence of shape (batch_size, seq_length, d_model).
        """
        # Masked self-attention over the target sequence: each position can
        # only attend to earlier positions.
        x = self.causal_self_attention(x=x)

        # Cross-attention over the encoder output (the context sequence).
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later.
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x
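For reference, here is a minimal, runnable sketch showing how DecoderLayer fits together. The CausalSelfAttention, CrossAttention and FeedForward layers below are simplified stand-ins modelled on the Keras Transformer tutorial; they are assumptions for illustration, and the definitions used earlier in this series may differ in detail (the sketch assumes TensorFlow 2.10+ for `use_causal_mask`).

# Simplified stand-in helper layers (assumption: not the exact definitions
# from this series, only illustrative sketches in the same spirit).
class BaseAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


class CausalSelfAttention(BaseAttention):
    def call(self, x):
        # Causal mask so each position attends only to earlier positions.
        attn_output = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        return self.layernorm(self.add([x, attn_output]))


class CrossAttention(BaseAttention):
    def call(self, x, context):
        attn_output, attn_scores = self.mha(
            query=x, key=context, value=context, return_attention_scores=True)
        self.last_attn_scores = attn_scores  # kept for later inspection
        return self.layernorm(self.add([x, attn_output]))


class FeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation="relu"),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ])
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        return self.layernorm(self.add([x, self.seq(x)]))


# Forward pass with random tensors standing in for real embeddings.
batch_size, seq_len, d_model = 2, 10, 128
decoder_layer = DecoderLayer(d_model=d_model, num_heads=8, dff=512)

x = tf.random.uniform((batch_size, seq_len, d_model))        # previous decoder output
context = tf.random.uniform((batch_size, seq_len, d_model))  # encoder output

out = decoder_layer(x, context=context)
print(out.shape)                             # (2, 10, 128)
print(decoder_layer.last_attn_scores.shape)  # (2, 8, 10, 10)

As the class docstring notes, a full decoder stacks several of these layers on top of the embedding and positional-encoding layers.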