#pytorch RoPE
import math

import torch
import torch.nn as nn


class RotaryPositionEmbedding(nn.Module):
    """
    Implements Rotary Position Embedding (RoPE) as a PyTorch module.

    Args:
        dim (int): Dimension of the embedding (must be even)
        max_position (int): Maximum sequence length to pre-compute
        base (float, optional): Base for frequency computation. Defaults to 10000.0
        scale (float, optional): Scaling factor for theta. Defaults to 1.0
    """

    def __init__(self, dim: int, max_position: int, base: float = 10000.0, scale: float = 1.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"Dimension {dim} must be even")

        self.dim = dim
        self.max_position = max_position
        self.base = base
        self.scale = scale
        # Pre-compute the inverse frequencies: inv_freq[i] = 1 / base^(2i / dim), shape [dim/2]
        self.register_buffer(
            "inv_freq",
            1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)),
        )
        self._generate_position_embeddings()

    def _generate_position_embeddings(self):
        """Generate and cache sinusoidal position embeddings of shape [max_position, dim/2]."""
        # Build the table on the same device as inv_freq so the module still works after .to()/.cuda()
        position = torch.arange(self.max_position, device=self.inv_freq.device).float()
        # Outer product of positions and inverse frequencies: [max_position, dim/2]
        position_mat = torch.outer(position, self.inv_freq)
        # Scale theta if needed
        position_mat = position_mat * self.scale

        self.register_buffer("sin_cached", torch.sin(position_mat))
        self.register_buffer("cos_cached", torch.cos(position_mat))

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor = None) -> torch.Tensor:
        """
        Apply rotary position embeddings to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape [..., seq_len, dim]
            position_ids (torch.Tensor, optional): Position indices of shape [batch, seq_len].
                If None, sequential positions are used.

        Returns:
            torch.Tensor: Rotary position embedded tensor of the same shape as the input
        """
        seq_len = x.size(-2)
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=x.device)
            position_ids = position_ids.unsqueeze(0).expand(x.size(0), -1)

        # Extend the cached tables if a requested position falls outside them
        if int(position_ids.max()) >= self.max_position:
            self._extend_position_embeddings(int(position_ids.max()) + 1)

        # Get sin and cos for the current positions: [batch, seq_len, dim/2]
        sin = self.sin_cached[position_ids]
        cos = self.cos_cached[position_ids]
        # Insert singleton axes just before the sequence axis so sin/cos broadcast
        # over any extra dimensions in x (e.g. attention heads)
        while sin.dim() < x.dim():
            sin = sin.unsqueeze(-3)
            cos = cos.unsqueeze(-3)

        # Group the last dimension into (even, odd) pairs: [..., seq_len, dim/2, 2]
        x_pairs = x.reshape(*x.shape[:-1], -1, 2)
        x1 = x_pairs[..., 0]  # even-indexed features
        x2 = x_pairs[..., 1]  # odd-indexed features

        # Rotate each pair by its position-dependent angle
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos,
        ], dim=-1)

        # Restore the original shape
        return rotated.flatten(-2)

    def _extend_position_embeddings(self, new_max_position: int):
        """Extend position embeddings to handle longer sequences."""
        self.max_position = new_max_position
        self._generate_position_embeddings()
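

# In summary, for each feature pair (x[2i], x[2i+1]) at position m, forward() applies a
# plain 2-D rotation with a position-dependent angle, as computed by the code above:
#     theta_i   = scale * m / base^(2i / dim)
#     x'[2i]    = x[2i] * cos(theta_i) - x[2i+1] * sin(theta_i)
#     x'[2i+1]  = x[2i] * sin(theta_i) + x[2i+1] * cos(theta_i)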

# Example usage
if __name__ == "__main__":
    # Parameters
    batch_size = 2
    seq_len = 8
    n_heads = 4
    head_dim = 64

    # Create module
    rope = RotaryPositionEmbedding(
        dim=head_dim,
        max_position=16,
        base=10000.0,
    )

    # Create sample input; the sequence axis must sit at dim -2,
    # so the layout is [batch, heads, seq_len, head_dim]
    x = torch.randn(batch_size, n_heads, seq_len, head_dim)

    # Apply RoPE
    rotated_x = rope(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {rotated_x.shape}")

    # Test with custom position ids
    position_ids = torch.tensor([[0, 2, 4, 6, 8, 10, 12, 14],
                                 [1, 3, 5, 7, 9, 11, 13, 15]])
    rotated_x_custom = rope(x, position_ids)
    print(f"Custom position output shape: {rotated_x_custom.shape}")