Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created August 10, 2023 13:19
Show Gist options
  • Select an option

  • Save pythonlessons/6888520cbe9ac183a8fdd9614f9e0fec to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/6888520cbe9ac183a8fdd9614f9e0fec to your computer and use it in GitHub Desktop.
transformers_introduction
class PositionalEmbedding(tf.keras.layers.Layer):
    """Token embedding followed by the addition of a positional encoding.

    The layer converts a batch of token ids into embedding vectors, scales
    them by ``sqrt(d_model)`` and adds a slice of a precomputed
    positional-encoding table, so that the layers that follow receive
    information about where each token sits in the sequence.

    Methods:
        compute_mask: Computes the mask to be applied to the embeddings.
        call: Performs the forward pass of the layer.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        embedding: tf.keras.layers.Embedding = None,
        *,
        max_length: int = 2048,
        **kwargs,
    ):
        """Constructor of the PositionalEmbedding layer.

        Args:
            vocab_size (int): The size of the vocabulary, i.e. the number of
                unique tokens in the input sequence. Only used when
                ``embedding`` is None.
            d_model (int): The dimensionality of the embedding vector.
            embedding (tf.keras.layers.Embedding): The custom embedding
                layer. If None, a default embedding layer with
                ``mask_zero=True`` is created.
            max_length (int): Number of positions requested from
                ``positional_encoding``. Inputs longer than this cannot be
                encoded, because ``call`` slices the table by sequence
                length. Defaults to 2048, the value previously hard-coded.
            **kwargs: Forwarded to ``tf.keras.layers.Layer`` (for example
                ``name`` or ``dtype``).
        """
        super().__init__(**kwargs)
        self.d_model = d_model
        self.max_length = max_length
        if embedding is None:
            embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.embedding = embedding
        # positional_encoding is defined outside this class; call() indexes
        # its result as [position, feature] -- presumably the shape is
        # (max_length, d_model), confirm against the helper.
        self.pos_encoding = positional_encoding(length=max_length, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        """Computes the mask to be applied to the embeddings.

        All arguments are passed straight to the wrapped embedding layer's
        ``compute_mask`` and its result is returned unchanged.
        """
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Performs the forward pass of the layer.

        Args:
            x (tf.Tensor): The input tensor of shape (batch_size, seq_length).

        Returns:
            tf.Tensor: The output sequence of embedding vectors with added
                positional information. The shape is
                (batch_size, seq_length, d_model).
        """
        x = self.embedding(x)
        length = tf.shape(x)[1]
        # This factor sets the relative scale of the embedding and the
        # positional encoding. It is built in the embedding's own dtype
        # because TensorFlow does not promote mixed float dtypes implicitly
        # (relevant e.g. under a mixed-precision policy).
        x *= tf.math.sqrt(tf.cast(self.d_model, x.dtype))
        # Keep only the first `length` positions and add a leading axis so the
        # table broadcasts over the batch dimension.
        x = x + tf.cast(self.pos_encoding, x.dtype)[tf.newaxis, :length, :]
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment