@pythonlessons
Created August 16, 2023 12:45
transformer_attention
import tensorflow as tf


class BaseAttention(tf.keras.layers.Layer):
    """
    Base class for all attention layers. It holds the functionality shared by all attention layers:
    a MultiHeadAttention layer, a LayerNormalization layer and an Add layer.
    It is the base class for the GlobalSelfAttention, CausalSelfAttention and CrossAttention layers
    and is not intended to be used directly.

    Methods:
        call: Performs the forward pass of the layer. It is not defined here; each subclass provides its own.

    Attributes:
        mha (tf.keras.layers.MultiHeadAttention): The MultiHeadAttention layer.
        layernorm (tf.keras.layers.LayerNormalization): The LayerNormalization layer.
        add (tf.keras.layers.Add): The Add layer.
    """
    def __init__(self, **kwargs):
        """ Constructor of the BaseAttention layer.

        Args:
            **kwargs: Additional keyword arguments that are passed to the MultiHeadAttention layer, e.g.
                num_heads (number of heads), key_dim (dimensionality of the key space), etc.
        """
        super().__init__()
        # Sub-layers shared by every attention variant built on this class.
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()
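
The docstring names GlobalSelfAttention as one of the subclasses, but this gist does not include it. The following is a minimal sketch of how such a subclass might use the three sub-layers. The body of `call` (self-attention, then a residual add, then layer normalization) and the input shape are assumptions for illustration, not the author's code. BaseAttention is repeated in shortened form so the example runs on its own.

```python
import tensorflow as tf


class BaseAttention(tf.keras.layers.Layer):
    # Shortened copy of the class above so this sketch is self-contained.
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


class GlobalSelfAttention(BaseAttention):
    # Assumed implementation: the gist only names this subclass.
    def call(self, x):
        # Self-attention: the sequence attends to itself.
        attn_output = self.mha(query=x, value=x, key=x)
        # Residual connection followed by layer normalization.
        x = self.add([x, attn_output])
        return self.layernorm(x)


# Hypothetical input: batch of 1, sequence length 5, model width 32.
layer = GlobalSelfAttention(num_heads=2, key_dim=16)
x = tf.random.uniform((1, 5, 32))
y = layer(x)
print(y.shape)  # (1, 5, 32)
```

The output keeps the input shape because MultiHeadAttention projects back to the last dimension of the query by default, which is what makes the residual add possible.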