RoPE encodes positional information by rotating embedding vectors in a way that:
- Preserves relative positions: The dot product between tokens depends on their relative distance
- Uses rotation: Each position gets rotated by an angle proportional to its position
- Works in pairs: Dimensions are grouped in pairs and rotated together
For a position $m$ and dimension pair $(2i, 2i+1)$, RoPE applies a 2D rotation:

$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$$

where $\theta_i = \text{base}^{-2i/d}$ is the rotation frequency for dimension pair $i$:
- base = 10000 (typically)
- d = total embedding dimension
- i = dimension pair index (0, 1, 2, ...)

Three properties follow:
- Different frequencies for different dimensions: lower dimension pairs rotate faster (larger $\theta_i$), higher dimension pairs rotate slower (smaller $\theta_i$)
- Complex number representation: the rotation can be elegantly expressed as multiplying by $e^{im\theta_i}$
- Relative position encoding: the angle between positions $m_1$ and $m_2$ is $(m_1 - m_2)\theta_i$

The exponential decay of $\theta_i$ ensures each dimension pair operates at a different "wavelength" for encoding position!
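To make the relative-position property concrete, here is a minimal numeric sketch (my own, not part of the original code): rotate one query/key dimension pair by their positions and check that the dot product depends only on the offset.

```python
import torch

# One frequency and one (query, key) dimension pair, values chosen arbitrarily.
theta = 0.5
q = torch.view_as_complex(torch.tensor([1.0, 2.0]))  # pair (x_0, x_1) as a complex number
k = torch.view_as_complex(torch.tensor([3.0, 4.0]))

def rotated_score(m_q: int, m_k: int) -> torch.Tensor:
    # <R(m_q) q, R(m_k) k> = Re(q * e^(i*m_q*theta) * conj(k * e^(i*m_k*theta)))
    rq = q * torch.polar(torch.tensor(1.0), torch.tensor(m_q * theta))
    rk = k * torch.polar(torch.tensor(1.0), torch.tensor(m_k * theta))
    return (rq * rk.conj()).real

print(rotated_score(5, 3))  # offset 2
print(rotated_score(9, 7))  # offset 2 → identical score
print(rotated_score(4, 3))  # offset 1 → different score
```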
Let me break down all three functions with simple, concrete examples.
```python
from typing import Tuple

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
```

Precomputes the rotation angles for all positions and all dimension pairs.
With `dim=4`:

```python
torch.arange(0, 4, 2)   # [0, 2]
[: (4 // 2)]            # [0, 2] (first 2 elements)
.float() / 4            # [0.0, 0.5]
theta ** (...)          # [10000^0.0, 10000^0.5] = [1.0, 100.0]
freqs = 1.0 / (...)     # [1.0, 0.01]
```

Result: `freqs = [1.0, 0.01]` (2 frequency values for 2 dimension pairs).
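You can verify these two frequencies directly (a quick sketch using the same formula as the function above):

```python
dim = 4
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
print(freqs)  # tensor([1.0000, 0.0100])
```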
With `end=3`:

```python
t = torch.arange(3)            # [0, 1, 2] (positions 0, 1, 2)
freqs = torch.outer(t, freqs)  # outer product: each position × each frequency

#        [1.0, 0.01]
# [0]   [0.0,  0.0 ]
# [1]   [1.0,  0.01]
# [2]   [2.0,  0.02]
```

Meaning:
- Position 0: rotations [0.0, 0.0]
- Position 1: rotations [1.0, 0.01]
- Position 2: rotations [2.0, 0.02]
```python
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
# polar(magnitude, angle) = magnitude * e^(i*angle)
# magnitude = 1.0 for all
```

Result: complex numbers representing rotations:

```
Position 0: [e^(i*0.0), e^(i*0.0) ] = [ 1.0+0.0i,   1.0+0.0i  ]
Position 1: [e^(i*1.0), e^(i*0.01)] = [ 0.54+0.84i, 0.99+0.01i]
Position 2: [e^(i*2.0), e^(i*0.02)] = [-0.42+0.91i, 0.99+0.02i]
```

Shape: `[3, 2]` → (3 positions, 2 dimension pairs)
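Running the function confirms the walkthrough (the hand trace above truncates $e^{i \cdot 0.01} \approx 0.99995 + 0.01i$ to 0.99+0.01i; PyTorch prints it with more precision):

```python
fc = precompute_freqs_cis(dim=4, end=3)
print(fc.shape)  # torch.Size([3, 2])
print(fc)
# ≈ [[ 1.00+0.00j, 1.00+0.00j],
#    [ 0.54+0.84j, 1.00+0.01j],
#    [-0.42+0.91j, 1.00+0.02j]]
```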
```python
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
```

Reshapes `freqs_cis` so it broadcasts correctly against the input tensor.
Suppose:
- `x` has shape `[batch, seq_len, num_heads, head_dim] = [2, 3, 4, 4]`
- After converting to complex, `x_complex` has shape `[2, 3, 4, 2]` (last dim halved)
- `freqs_cis` has shape `[3, 2]` (seq_len=3, head_dim/2=2)
```python
ndim = 4  # x_complex has 4 dimensions

# Check: freqs_cis.shape should be (x.shape[1], x.shape[-1])
# freqs_cis.shape = (3, 2) ✓
# x.shape[1] = 3 (seq_len), x.shape[-1] = 2 (head_dim/2) ✓

# Build the new shape
shape = []
for i, d in enumerate([2, 3, 4, 2]):
    if i == 1 or i == 3:   # keep dimensions 1 and 3
        shape.append(d)
    else:
        shape.append(1)    # add singleton dimensions

shape = [1, 3, 1, 2]
```

Result: `freqs_cis.view(1, 3, 1, 2)`
Why? This allows broadcasting:

```
x_complex: [2, 3, 4, 2]
freqs_cis: [1, 3, 1, 2]
            ↓  ↓  ↓  ↓
After broadcast: [2, 3, 4, 2]  (multiplication works!)
```
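A quick shape check of this example, using the two functions defined above:

```python
x_complex = torch.randn(2, 3, 4, 2, dtype=torch.complex64)
fc = precompute_freqs_cis(dim=4, end=3)   # shape [3, 2]
fc_b = reshape_for_broadcast(fc, x_complex)
print(fc_b.shape)                # torch.Size([1, 3, 1, 2])
print((x_complex * fc_b).shape)  # torch.Size([2, 3, 4, 2])
```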
```python
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```

Applies rotary embeddings to the query and key tensors.
Let's trace through with concrete numbers!
```python
xq.shape = [1, 2, 1, 4]     # [batch, seq_len, heads, head_dim]
xk.shape = [1, 2, 1, 4]
freqs_cis.shape = [2, 2]    # [seq_len, head_dim/2]

# Suppose xq looks like:
xq = [
    [                             # batch 0
        [[1.0, 2.0, 3.0, 4.0]],   # position 0, head 0
        [[5.0, 6.0, 7.0, 8.0]]    # position 1, head 0
    ]
]
```

`xq.float().reshape(*xq.shape[:-1], -1, 2)`:
```python
# [1, 2, 1, 4] → [1, 2, 1, 2, 2]
# batch, seq, heads, pairs, (real, imag)
[
    [
        [[[1.0, 2.0], [3.0, 4.0]]],   # pos 0: pairs (1,2) and (3,4)
        [[[5.0, 6.0], [7.0, 8.0]]]    # pos 1: pairs (5,6) and (7,8)
    ]
]
```

`xq_ = torch.view_as_complex(...)`:
```python
# Shape: [1, 2, 1, 2] (complex numbers)
[
    [
        [[1.0+2.0i, 3.0+4.0i]],   # pos 0
        [[5.0+6.0i, 7.0+8.0i]]    # pos 1
    ]
]
```

`freqs_cis = reshape_for_broadcast(freqs_cis, xq_)`:
```python
# [2, 2] → [1, 2, 1, 2]
# Suppose freqs_cis contains:
[
    [
        [[1.0+0.0i, 1.0+0.0i]],      # pos 0: no rotation
        [[0.54+0.84i, 0.99+0.01i]]   # pos 1: rotate!
    ]
]
```

`xq_out = xq_ * freqs_cis`:
Position 0 (multiplying by 1 leaves it unchanged):

```
(1.0+2.0i) * (1.0+0.0i) = 1.0+2.0i
(3.0+4.0i) * (1.0+0.0i) = 3.0+4.0i
```

Position 1:

```
(5.0+6.0i) * (0.54+0.84i) = 5.0*0.54 - 6.0*0.84 + i(5.0*0.84 + 6.0*0.54)
                          = 2.7 - 5.04 + i(4.2 + 3.24)
                          = -2.34 + 7.44i
(7.0+8.0i) * (0.99+0.01i) ≈ 6.85 + 8.0i
```

`torch.view_as_real(xq_out)`:
```python
# [1, 2, 1, 2] complex → [1, 2, 1, 2, 2] real
[
    [
        [[[1.0, 2.0], [3.0, 4.0]]],      # pos 0: unchanged
        [[[-2.34, 7.44], [6.85, 8.0]]]   # pos 1: rotated!
    ]
]
```

`.flatten(3)` (flatten the last two dims):
```python
# [1, 2, 1, 2, 2] → [1, 2, 1, 4]
[
    [
        [[1.0, 2.0, 3.0, 4.0]],      # pos 0
        [[-2.34, 7.44, 6.85, 8.0]]   # pos 1: rotated!
    ]
]
```
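The whole trace can be reproduced end to end (a sketch; the printed values differ in the last digits because the hand trace rounded the rotations to two decimals):

```python
xq = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]],
                    [[5.0, 6.0, 7.0, 8.0]]]])  # [1, 2, 1, 4]
xk = xq.clone()
fc = precompute_freqs_cis(dim=4, end=2)        # [2, 2]
xq_rot, xk_rot = apply_rotary_emb(xq, xk, fc)
print(xq_rot)
# ≈ [[[[ 1.0000, 2.0000, 3.0000, 4.0000]],
#     [[-2.3473, 7.4492, 6.9197, 8.0696]]]]
```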
| Function | Purpose | Input → Output |
|---|---|---|
| `precompute_freqs_cis` | Compute rotation angles | `dim, seq_len` → complex rotations `[seq_len, dim/2]` |
| `reshape_for_broadcast` | Prepare for broadcasting | `[seq_len, dim/2]` → `[1, seq_len, 1, dim/2]` |
| `apply_rotary_emb` | Rotate embeddings | queries, keys → rotated queries, keys |
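Putting it together, here is a minimal usage sketch of how the three functions typically combine inside an attention layer; the `wq`/`wk` projections and the shapes are illustrative assumptions, not from the original code:

```python
batch, seq_len, n_heads, head_dim = 2, 16, 4, 8
wq = torch.nn.Linear(n_heads * head_dim, n_heads * head_dim, bias=False)  # stand-in projection
wk = torch.nn.Linear(n_heads * head_dim, n_heads * head_dim, bias=False)  # stand-in projection

x = torch.randn(batch, seq_len, n_heads * head_dim)
xq = wq(x).view(batch, seq_len, n_heads, head_dim)   # [2, 16, 4, 8]
xk = wk(x).view(batch, seq_len, n_heads, head_dim)

freqs_cis = precompute_freqs_cis(head_dim, seq_len)  # [16, 4]
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
# Attention scores computed from xq and xk now encode relative positions.
```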
Each position gets rotated by different angles, encoding its position. The rotation is done by treating consecutive pairs of dimensions as 2D vectors and rotating them in that 2D plane. This creates a positional encoding that naturally captures relative positions through the geometry of rotations.