#pytorch RoPE
import math
import torch
import torch.nn as nn
class RotaryPositionEmbedding(nn.Module):
    """
    Implements Rotary Position Embedding (RoPE) as a PyTorch module.

    Args:
        dim (int): Dimension of the embedding (must be even)
        max_position (int): Maximum sequence length to pre-compute
        base (float, optional): Base for frequency computation. Defaults to 10000.0
        scale (float, optional): Scaling factor for theta. Defaults to 1.0
    """

    def __init__(self, dim: int, max_position: int, base: float = 10000.0, scale: float = 1.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"Dimension {dim} must be even")
        self.dim = dim
        self.max_position = max_position
        self.base = base
        self.scale = scale
        # Pre-compute inverse frequencies: theta_i = base^(-2i / dim)
        self.register_buffer(
            "inv_freq",
            1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)),
        )
        self._generate_position_embeddings()

    def _generate_position_embeddings(self):
        """Generate sinusoidal position embeddings."""
        # [max_position, dim/2]
        position = torch.arange(self.max_position, device=self.inv_freq.device).float()
        position_mat = torch.outer(position, self.inv_freq)
        # Scale theta if needed
        position_mat = position_mat * self.scale
        # [max_position, dim/2]
        self.register_buffer("sin_cached", torch.sin(position_mat))
        self.register_buffer("cos_cached", torch.cos(position_mat))

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor = None) -> torch.Tensor:
        """
        Apply rotary position embeddings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape [batch, ..., seq_len, dim]
            position_ids (torch.Tensor, optional): Position indices of shape [batch, seq_len].
                If None, sequential positions are used.

        Returns:
            torch.Tensor: Rotary position embedded tensor of the same shape as the input
        """
        seq_len = x.size(-2)
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=x.device)
            position_ids = position_ids.unsqueeze(0).expand(x.size(0), -1)
        # Extend the cached tables if any position falls outside the current range
        if position_ids.max() >= self.max_position:
            self._extend_position_embeddings(int(position_ids.max()) + 1)

        # Get sin and cos for the current positions: [batch, seq_len, dim/2]
        sin = self.sin_cached[position_ids]
        cos = self.cos_cached[position_ids]
        # Insert singleton dimensions so sin/cos broadcast over any axes
        # between batch and seq_len (e.g. attention heads)
        while sin.dim() < x.dim():
            sin = sin.unsqueeze(1)
            cos = cos.unsqueeze(1)

        # Split the features into (even, odd) pairs
        x1 = x[..., 0::2]  # even indices
        x2 = x[..., 1::2]  # odd indices
        # Rotate each pair by its position-dependent angle
        rotated = torch.stack(
            [
                x1 * cos - x2 * sin,
                x1 * sin + x2 * cos,
            ],
            dim=-1,
        )
        # Interleave the rotated pairs back into the original layout
        return rotated.flatten(-2)

    def _extend_position_embeddings(self, new_max_position: int):
        """Extend position embeddings to handle longer sequences."""
        self.max_position = new_max_position
        self._generate_position_embeddings()


# Example usage
if __name__ == "__main__":
    # Parameters
    batch_size = 2
    seq_len = 8
    n_heads = 4
    head_dim = 64

    # Create module
    rope = RotaryPositionEmbedding(
        dim=head_dim,
        max_position=16,
        base=10000.0,
    )

    # Create sample input with layout [batch, heads, seq_len, head_dim]
    x = torch.randn(batch_size, n_heads, seq_len, head_dim)

    # Apply RoPE
    rotated_x = rope(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {rotated_x.shape}")

    # Test with custom position ids
    position_ids = torch.tensor([[0, 2, 4, 6, 8, 10, 12, 14],
                                 [1, 3, 5, 7, 9, 11, 13, 15]])
    rotated_x_custom = rope(x, position_ids)
    print(f"Custom position output shape: {rotated_x_custom.shape}")