# build_transformer
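# The DecoderLayer below builds on three blocks defined earlier in the
# tutorial series: CausalSelfAttention, CrossAttention, and FeedForward.
# They are not part of this gist, so the sketch here is an assumption added
# for self-containedness, following the standard residual add-and-norm
# pattern around tf.keras.layers.MultiHeadAttention; only the class names
# and call signatures are taken from the code below.
import tensorflow as tf


class BaseAttention(tf.keras.layers.Layer):
    """Shared skeleton: multi-head attention + residual add + layer norm."""
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


class CausalSelfAttention(BaseAttention):
    """Self-attention over the target sequence with a causal (look-ahead) mask."""
    def call(self, x: tf.Tensor) -> tf.Tensor:
        attn_output = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        x = self.add([x, attn_output])
        return self.layernorm(x)


class CrossAttention(BaseAttention):
    """Attention from the target sequence (queries) over the encoder output (keys/values)."""
    def call(self, x: tf.Tensor, context: tf.Tensor) -> tf.Tensor:
        attn_output, attn_scores = self.mha(
            query=x, value=context, key=context,
            return_attention_scores=True)
        # Cache the scores so DecoderLayer can expose them for plotting.
        self.last_attn_scores = attn_scores
        x = self.add([x, attn_output])
        return self.layernorm(x)


class FeedForward(tf.keras.layers.Layer):
    """Position-wise feed-forward network with residual add + layer norm."""
    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate),
        ])
        self.add = tf.keras.layers.Add()
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.add([x, self.seq(x)])
        return self.layernorm(x)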
class DecoderLayer(tf.keras.layers.Layer):
    """
    A single layer of the Decoder. Usually multiple such layers are stacked on top of each other.

    Methods:
        call: Performs the forward pass of the layer.

    Attributes:
        causal_self_attention (CausalSelfAttention): The causal self-attention layer.
        cross_attention (CrossAttention): The cross-attention layer.
        ffn (FeedForward): The feed-forward layer.
    """
    def __init__(self, d_model: int, num_heads: int, dff: int, dropout_rate: float = 0.1):
        """
        Constructor of the DecoderLayer.

        Args:
            d_model (int): The dimensionality of the model.
            num_heads (int): The number of heads in the multi-head attention layer.
            dff (int): The dimensionality of the feed-forward layer.
            dropout_rate (float): The dropout rate.
        """
        super(DecoderLayer, self).__init__()
        self.causal_self_attention = CausalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x: tf.Tensor, context: tf.Tensor) -> tf.Tensor:
        """
        The call function that performs the forward pass of the layer.

        Args:
            x (tf.Tensor): The target sequence of shape (batch_size, target_seq_len, d_model), usually the output of the previous decoder layer.
            context (tf.Tensor): The context sequence of shape (batch_size, context_seq_len, d_model), usually the output of the encoder.

        Returns:
            tf.Tensor: The output sequence of shape (batch_size, target_seq_len, d_model).
        """
        x = self.causal_self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, target_seq_len, d_model)`.
        return x
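

# A quick smoke test of the forward pass (a minimal sketch; the dimensions
# below are arbitrary choices, not values from the gist). Note that the
# target and context sequence lengths may differ.
sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048)

x = tf.random.uniform((2, 5, 512))        # target-side embeddings
context = tf.random.uniform((2, 7, 512))  # encoder output

out = sample_decoder_layer(x=x, context=context)
print(out.shape)                                    # (2, 5, 512)
print(sample_decoder_layer.last_attn_scores.shape)  # (2, 8, 5, 7)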