Skip to content

Instantly share code, notes, and snippets.

@Sinjhin
Created November 28, 2024 19:36
Show Gist options
  • Save Sinjhin/7e93e6f0242a8be43bc199e99df0ee60 to your computer and use it in GitHub Desktop.
Save Sinjhin/7e93e6f0242a8be43bc199e99df0ee60 to your computer and use it in GitHub Desktop.
Positional Encoding for a Transformer
import torch
import matplotlib.pyplot as plt
def visualize_positional_encoding(seq_length: int = 83, d_model: int = 32, *, plot: bool = True):
    """Build the sinusoidal positional-encoding matrix and plot it as a heatmap.

    Uses the Transformer formulation in which each sin/cos pair of columns
    shares a single frequency:

        PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    The first few values of rows 0, 1 and 2 (those that exist) are printed.

    Args:
        seq_length: Number of positions, i.e. rows of the matrix.
        d_model: Embedding dimension, i.e. columns of the matrix. Odd values
            are supported; the last column is then a sine with no cosine partner.
        plot: If True (default), draw the heatmap with matplotlib and show it.
            Pass False to skip all plotting (e.g. in a headless environment).

    Returns:
        The encoding as a NumPy array of shape ``(seq_length, d_model)``.
    """
    positions = torch.arange(seq_length, dtype=torch.float32).unsqueeze(-1)
    # One frequency per sin/cos pair: exponent 2i / d_model for i = 0, 1, ...
    pair_dims = torch.arange(0, d_model, 2, dtype=torch.float32)
    div_term = 10000.0 ** (pair_dims / d_model)
    angles = positions / div_term  # (seq_length, ceil(d_model / 2))

    pos_encoding = torch.zeros(seq_length, d_model)
    pos_encoding[:, 0::2] = torch.sin(angles)
    # Odd columns reuse the even columns' frequencies (the previous code used
    # exponent (2i + 1) / d_model here, which breaks the sin/cos pairing).
    # With odd d_model there is one fewer odd column, so drop the last frequency.
    pos_encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    encoding_matrix = pos_encoding.numpy()

    # zip stops at the shorter sequence, so short inputs no longer raise IndexError.
    row_labels = (
        "Position 0 (Height):",
        "Position 1 (Width):",
        "Position 2 (First grid value):",
    )
    for label, row in zip(row_labels, encoding_matrix):
        print(label, row[:5])

    if plot:
        plt.figure(figsize=(10, 8))
        plt.imshow(encoding_matrix, cmap='viridis', aspect='auto')
        plt.colorbar()
        plt.xlabel('Embedding Dimension')
        plt.ylabel('Sequence Position')
        plt.title('Positional Encoding Pattern')
        plt.show()

    return encoding_matrix
# Run the demo only when executed as a script, so importing this module does
# not open a plot window as a side effect.
if __name__ == "__main__":
    visualize_positional_encoding()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment