```python
class CausalSelfAttention(BaseAttention):
    """
    Applies self-attention to the input sequence, ensuring that each position in the
    output depends only on the current and previous positions (i.e. a causal model).

    Methods:
        call: Performs the forward pass of the layer.

    Attributes:
        mha (tf.keras.layers.MultiHeadAttention): The MultiHeadAttention layer.
    """
```
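The docstring states the causal constraint but not how it is enforced. The following NumPy sketch is an illustration, not this layer's implementation: it computes single-head scaled dot-product attention with a lower-triangular mask, and for brevity it uses the input itself as queries, keys and values instead of learned projections. In Keras the same constraint can be requested by passing `use_causal_mask=True` when calling `tf.keras.layers.MultiHeadAttention` (TensorFlow 2.10 and later); whether this layer uses that argument is an assumption. The helper names `softmax` and `causal_self_attention` are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def causal_self_attention(x):
    """Hypothetical single-head causal self-attention; x has shape (seq_len, d_model)."""
    seq_len, d_model = x.shape
    scores = x @ x.T / np.sqrt(d_model)  # (seq_len, seq_len)
    # Lower-triangular mask: position i may attend to positions 0..i only
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -np.inf)  # future positions get zero weight
    weights = softmax(scores, axis=-1)
    return weights @ x, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 8))
out, w = causal_self_attention(x)

# Change only the last token and recompute
x2 = x.copy()
x2[-1] += 10.0
out2, _ = causal_self_attention(x2)
print(np.allclose(out[:-1], out2[:-1]))  # True
```

The final check demonstrates the defining property: perturbing the last token leaves every earlier output unchanged, because all weights above the diagonal are exactly zero.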
```python
import numpy as np

encoder_vocab_size = 1000
d_model = 512
encoder_embedding_layer = PositionalEmbedding(encoder_vocab_size, d_model)
random_encoder_input = np.random.randint(0, encoder_vocab_size, size=(1, 100))
encoder_embeddings = encoder_embedding_layer(random_encoder_input)
print("encoder_embeddings shape", encoder_embeddings.shape)
```
```python
class GlobalSelfAttention(BaseAttention):
    """
    A class that implements the global self-attention layer by inheriting from the BaseAttention class.
    This layer processes a single sequence, and every position attends to all the tokens in the sequence.

    Methods:
        call: Performs the forward pass of the layer.

    Attributes:
        mha (tf.keras.layers.MultiHeadAttention): The MultiHeadAttention layer.
    """
```
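For contrast with the causal case, here is a minimal NumPy sketch of unmasked (global) self-attention under the same simplifications: one head, no learned projections, and hypothetical helper names. No position is masked, so every row of the attention matrix spreads weight over the whole sequence, and a change to the last token reaches the output at the first position.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def global_self_attention(x):
    """Hypothetical single-head self-attention with no mask; x has shape (seq_len, d_model)."""
    scores = x @ x.T / np.sqrt(x.shape[-1])  # (seq_len, seq_len), nothing masked
    weights = softmax(scores, axis=-1)       # every row sums to 1 over ALL positions
    return weights @ x, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 8))
out, w = global_self_attention(x)

# Perturb only the last token and measure how far the FIRST output moves
x2 = x.copy()
x2[-1] += 10.0
out2, _ = global_self_attention(x2)
print(np.abs(out2[0] - out[0]).max())
```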
```python
import numpy as np

encoder_vocab_size = 1000
decoder_vocab_size = 1100
d_model = 512
# Each embedding layer is sized by its own vocabulary
encoder_embedding_layer = PositionalEmbedding(encoder_vocab_size, d_model)
decoder_embedding_layer = PositionalEmbedding(decoder_vocab_size, d_model)
random_encoder_input = np.random.randint(0, encoder_vocab_size, size=(1, 100))
random_decoder_input = np.random.randint(0, decoder_vocab_size, size=(1, 110))
```
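Each embedding layer needs a lookup table with one row per token ID in its own vocabulary, which is why the encoder and decoder layers take different vocabulary sizes. The NumPy sketch below stands in for the two lookup tables with hypothetical random matrices. It also shows that the upper bound of the random-ID generator is exclusive (as it is for `np.random.randint`), so generated IDs always fit their own table, while an ID from the larger decoder vocabulary has no row in the encoder table.

```python
import numpy as np

encoder_vocab_size = 1000
decoder_vocab_size = 1100
d_model = 512
rng = np.random.default_rng(0)

# Hypothetical stand-ins for the two layers' lookup tables: one row per token ID
encoder_table = rng.normal(size=(encoder_vocab_size, d_model)).astype(np.float32)
decoder_table = rng.normal(size=(decoder_vocab_size, d_model)).astype(np.float32)

# The upper bound is exclusive, so IDs lie in [0, vocab_size - 1]
encoder_ids = rng.integers(0, encoder_vocab_size, size=(1, 100))
decoder_ids = rng.integers(0, decoder_vocab_size, size=(1, 110))

encoder_embeddings = encoder_table[encoder_ids]  # (1, 100, 512)
decoder_embeddings = decoder_table[decoder_ids]  # (1, 110, 512)
print(encoder_embeddings.shape, decoder_embeddings.shape)  # (1, 100, 512) (1, 110, 512)

# A decoder ID looked up in the smaller encoder table has no matching row
out_of_range = False
try:
    encoder_table[np.array([[decoder_vocab_size - 1]])]  # ID 1099 vs. 1000 rows
except IndexError:
    out_of_range = True
print("out of range:", out_of_range)  # out of range: True
```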
```python
class CrossAttention(BaseAttention):
    """
    A class that implements the cross-attention layer by inheriting from the BaseAttention class.
    This layer processes two different sequences: while processing the query sequence, it attends to the context sequence.

    Methods:
        call: Performs the forward pass of the layer.

    Attributes:
        mha (tf.keras.layers.MultiHeadAttention): The MultiHeadAttention layer.
    """
```
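A minimal NumPy sketch of cross-attention, again single-head and without learned projections (the helper names are hypothetical). Queries come from one sequence while keys and values come from the context, so the two sequences may have different lengths and the output keeps the query sequence's length. The lengths 110 and 100 mirror the random decoder and encoder inputs used in the embedding example.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(query_seq, context_seq):
    """Hypothetical single-head cross-attention.

    query_seq:   (target_len, d_model), e.g. decoder states
    context_seq: (source_len, d_model), e.g. encoder output
    """
    d_model = query_seq.shape[-1]
    scores = query_seq @ context_seq.T / np.sqrt(d_model)  # (target_len, source_len)
    weights = softmax(scores, axis=-1)  # each query position distributes weight over the context
    return weights @ context_seq, weights  # output has the query sequence's length


rng = np.random.default_rng(0)
decoder_states = rng.normal(size=(110, 16))
encoder_output = rng.normal(size=(100, 16))
out, w = cross_attention(decoder_states, encoder_output)
print(out.shape, w.shape)  # (110, 16) (110, 100)
```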
```python
import tensorflow as tf


class BaseAttention(tf.keras.layers.Layer):
    """
    Base class for all attention layers. It contains the functionality common to all attention layers:
    a MultiHeadAttention layer, a LayerNormalization layer and an Add layer.
    It is used as the base class for the GlobalSelfAttention, CausalSelfAttention and CrossAttention layers
    and is not intended to be used directly.

    Methods:
        call: Performs the forward pass of the layer.
    """
```
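The base-class design can be sketched in plain NumPy as a hypothetical class hierarchy: the base holds the shared attention computation, the residual `Add` and the layer normalization, and each subclass only decides where queries, keys and values come from and whether a causal mask applies. The sketch assumes the post-norm order (add the residual, then normalize) and omits multiple heads, learned projections, and LayerNormalization's learnable scale and offset; the `*Sketch` class names are illustrative and are not the document's classes.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-6):
    # Per-position normalization over the feature axis (no learnable scale/offset)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


class BaseAttentionSketch:
    """Hypothetical base: attention, then a residual Add, then layer normalization."""

    def attend(self, query, context, causal=False):
        # Single-head scaled dot-product attention without learned projections
        scores = query @ context.T / np.sqrt(query.shape[-1])
        if causal:
            allowed = np.tril(np.ones(scores.shape, dtype=bool))
            scores = np.where(allowed, scores, -np.inf)
        return softmax(scores) @ context

    def add_and_norm(self, x, attn_output):
        return layer_norm(x + attn_output)


class GlobalSelfAttentionSketch(BaseAttentionSketch):
    def __call__(self, x):
        return self.add_and_norm(x, self.attend(x, x))


class CausalSelfAttentionSketch(BaseAttentionSketch):
    def __call__(self, x):
        return self.add_and_norm(x, self.attend(x, x, causal=True))


class CrossAttentionSketch(BaseAttentionSketch):
    def __call__(self, x, context):
        # Queries from x; keys and values from the other (context) sequence
        return self.add_and_norm(x, self.attend(x, context))


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))
context = rng.normal(size=(7, 8))
outputs = [
    GlobalSelfAttentionSketch()(x),
    CausalSelfAttentionSketch()(x),
    CrossAttentionSketch()(x, context),
]
for out in outputs:
    print(out.shape)  # (5, 8) for each layer
```

Keeping the residual and normalization in one place means the subclasses differ only in a single line of their `__call__`, which mirrors how the three attention classes share one base.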
```
random_input shape (1, 100)
PositionalEmbedding output (1, 100, 512)
```
```python
import numpy as np

vocab_size = 1000
d_model = 512
embedding_layer = PositionalEmbedding(vocab_size, d_model)
random_input = np.random.randint(0, vocab_size, size=(1, 100))
output = embedding_layer(random_input)
print("random_input shape", random_input.shape)
print("PositionalEmbedding output", output.shape)
```
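The `PositionalEmbedding` docstring lists a `compute_mask` method. Assuming it behaves like `tf.keras.layers.Embedding(..., mask_zero=True)`, which treats token ID 0 as padding, the mask is simply `ids != 0`. The NumPy sketch below computes that mask for a random input like the one in the example. With 100 IDs drawn uniformly from 0 to 999, the chance that at least one of them is 0 is 1 - 0.999^100, roughly 9.5%, so a random test input can occasionally contain a position that this assumed mask would treat as padding.

```python
import numpy as np

vocab_size = 1000
rng = np.random.default_rng(0)
random_input = rng.integers(0, vocab_size, size=(1, 100))

# Assumption: ID 0 means padding, as with tf.keras.layers.Embedding(mask_zero=True);
# the mask is True for real tokens and False for padding positions
mask = random_input != 0
print(mask.shape, mask.dtype)  # (1, 100) bool

# A padded example: two real tokens followed by three padding positions
padded = np.array([[17, 42, 0, 0, 0]])
print(padded != 0)  # [[ True  True False False False]]
```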
```python
import tensorflow as tf


class PositionalEmbedding(tf.keras.layers.Layer):
    """
    A positional embedding layer combines the input embedding with a positional encoding that helps the Transformer
    understand the relative positions of the input tokens. The layer takes a sequence of token IDs, converts it
    into a sequence of embedding vectors, and then adds the positional encoding to the embeddings.

    Methods:
        compute_mask: Computes the mask to be applied to the embeddings.
        call: Performs the forward pass of the layer.
    """
```
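The docstring says the layer adds a positional encoding to the token embeddings. One common choice is the sinusoidal encoding from "Attention Is All You Need"; a minimal NumPy sketch of it follows. The helper `sinusoidal_encoding`, the random `embedding_table`, and the layout (sines in the first half of the columns, cosines in the second, rather than the paper's interleaving) are assumptions for illustration. The table is precomputed for a maximum length and sliced to the sequence length, and broadcasting adds the same `(seq_len, d_model)` encoding to every sequence in the batch.

```python
import numpy as np


def sinusoidal_encoding(length, depth):
    """Hypothetical helper: sinusoidal positional encoding of shape (length, depth).

    Column i (i < depth/2) holds sin(pos * w_i); column depth/2 + i holds cos(pos * w_i),
    with w_i = 1 / 10000**(i / (depth/2)).
    """
    half = depth // 2
    positions = np.arange(length)[:, np.newaxis]       # (length, 1)
    freqs = 1.0 / (10000 ** (np.arange(half) / half))  # (half,)
    angles = positions * freqs[np.newaxis, :]          # (length, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1).astype(np.float32)


vocab_size, d_model, seq_len = 1000, 512, 100
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, d_model)).astype(np.float32)  # hypothetical lookup table
pos_table = sinusoidal_encoding(length=2048, depth=d_model)                 # precomputed once

token_ids = rng.integers(0, vocab_size, size=(1, seq_len))
x = embedding_table[token_ids]           # (1, 100, 512) token embeddings
x = x + pos_table[np.newaxis, :seq_len]  # add the encoding for positions 0..99
print(x.shape)  # (1, 100, 512)
```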
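```python
import matplotlib.pyplot as plt
import tensorflow as tf

# Assumes pos_encoding, a (length, depth) positional-encoding tensor, is already defined
# Normalize every position's vector to unit length, so dot products become cosine similarities
pos_encoding /= tf.norm(pos_encoding, axis=1, keepdims=True)
p = pos_encoding[1000]
# Similarity between position 1000 and every position
dots = tf.einsum('pd,d->p', pos_encoding, p).numpy()
plt.subplot(2, 1, 1)
plt.plot(dots)
plt.ylim([0, 1])
# Vertical black lines at x=950 and x=1050 mark the zoom window; the NaN breaks the line between them
plt.plot([950, 950, float('nan'), 1050, 1050], [0, 1, float('nan'), 0, 1], color='k', label='Zoom')
plt.legend()
```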
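The similarity check in the plotting snippet can be reproduced without TensorFlow or matplotlib. This sketch assumes a sinusoidal encoding with concatenated sine and cosine halves, 2048 positions and depth 512; `sinusoidal_encoding` is a hypothetical helper and not necessarily how `pos_encoding` was built in the original code. Because sin² + cos² = 1 for every frequency, all rows have the same norm, and after normalization the dot product between positions p and q is the mean of cos(ω_i (p - q)). That value depends only on the offset p - q, equals 1 at p = q, and is symmetric around position 1000.

```python
import numpy as np


def sinusoidal_encoding(length, depth):
    # Hypothetical helper: sines in the first half of the columns, cosines in the second
    half = depth // 2
    angles = np.arange(length)[:, np.newaxis] / (10000 ** (np.arange(half) / half))
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


pos_encoding = sinusoidal_encoding(length=2048, depth=512)
# Unit-normalize each row so that dot products are cosine similarities
pos_encoding /= np.linalg.norm(pos_encoding, axis=1, keepdims=True)
p = pos_encoding[1000]
dots = np.einsum('pd,d->p', pos_encoding, p)  # similarity of every position to position 1000

print(dots[1000])        # self-similarity, 1 up to floating-point rounding
print(np.argmax(dots))   # the most similar position is position 1000 itself
print(dots[999], dots[1001])  # equal up to rounding: only the offset matters
```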