@pythonlessons, created August 16, 2023
transformer_attention
import tensorflow as tf


class CausalSelfAttention(BaseAttention):
    """
    Applies self-attention to the input sequence, ensuring that each position in the
    output depends only on previous positions (i.e. a causal model).

    Methods:
        call: Performs the forward pass of the layer.

    Attributes:
        mha (tf.keras.layers.MultiHeadAttention): The MultiHeadAttention layer.
        layernorm (tf.keras.layers.LayerNormalization): The LayerNormalization layer.
        add (tf.keras.layers.Add): The Add layer.
    """
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """
        Performs the causal self-attention operation.

        Args:
            x (tf.Tensor): The input sequence of shape (batch_size, seq_length, d_model).

        Returns:
            tf.Tensor: The output sequence of shape (batch_size, seq_length, d_model).
        """
        # Self-attention: query, key and value all come from x; the causal mask
        # prevents each position from attending to later positions.
        attn_output = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        # Residual connection followed by layer normalization.
        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x
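The gist does not include the BaseAttention parent class that creates the mha, layernorm and add sub-layers. A minimal sketch of what it presumably looks like (following the common TensorFlow transformer-tutorial pattern), plus a quick shape check, is below; the BaseAttention definition would need to sit above CausalSelfAttention, and the num_heads/key_dim values in the example are illustrative assumptions.

    # Assumed parent class (not part of the gist): wires up the sub-layers
    # that CausalSelfAttention uses in its call method.
    class BaseAttention(tf.keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__()
            self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
            self.layernorm = tf.keras.layers.LayerNormalization()
            self.add = tf.keras.layers.Add()

    # Usage sketch: the output keeps the (batch_size, seq_length, d_model) shape.
    layer = CausalSelfAttention(num_heads=2, key_dim=64)  # hypothetical sizes
    x = tf.random.uniform((1, 10, 64))                    # (batch, seq, d_model)
    print(layer(x).shape)                                 # (1, 10, 64)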