RoPE (Rotary Position Embedding) Explained

The Mathematics Behind RoPE

Core Concept

RoPE encodes positional information by rotating embedding vectors in a way that:

  1. Preserves relative positions: The dot product between query and key vectors depends only on their relative distance
  2. Uses rotation: Each position gets rotated by an angle proportional to its position
  3. Works in pairs: Dimensions are grouped in pairs and rotated together

The Mathematical Formula

For a position m and dimension pair (2i, 2i+1), RoPE applies a 2D rotation:

$$ \begin{bmatrix} q_{2i}^{(m)} \\ q_{2i+1}^{(m)} \end{bmatrix} = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} \begin{bmatrix} q_{2i} \\ q_{2i+1} \end{bmatrix} $$

Where θ_i is the rotation frequency for dimension pair i:

$$ \theta_i = \frac{1}{\text{base}^{2i/d}} = \text{base}^{-2i/d} $$

  • base = 10000 (typically)
  • d = total embedding dimension
  • i = dimension pair index (0, 1, 2, ...)
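As a quick sanity check on this formula, here is a minimal snippet that computes $\theta_i$ directly; the values dim=8 and base=10000 are chosen purely for illustration:

import torch

dim = 8         # total embedding dimension (illustrative)
base = 10000.0  # the typical RoPE base

# theta_i = base^(-2i/dim) for each dimension pair i = 0 .. dim/2 - 1
i = torch.arange(0, dim // 2).float()
theta = base ** (-2 * i / dim)
print(theta)  # ~[1.0, 0.1, 0.01, 0.001]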

Why This Works

  1. Different frequencies for different dimensions: Lower dimension pairs rotate faster (larger θ_i), higher dimension pairs rotate slower (smaller θ_i)
  2. Complex number representation: The rotation can be elegantly expressed as multiplying by $e^{im\theta_i}$
  3. Relative position encoding: The angle between positions m₁ and m₂ is $(m_1 - m_2)\theta_i$

The exponential decay ensures each dimension pair operates at a different "wavelength" for encoding position!
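
To make the relative-position claim concrete, treat each dimension pair as a single complex number. If a query pair $q$ sits at position $m$ and a key pair $k$ at position $n$, their rotated complex inner product is

$$ \langle q e^{im\theta_i}, k e^{in\theta_i} \rangle = q \bar{k} e^{i(m-n)\theta_i} $$

which depends on $m$ and $n$ only through the offset $m - n$: the absolute positions cancel, and only relative distance survives.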


Complete RoPE Code Explanation with Examples

Let me break down all three functions with simple, concrete examples.

Function 1: precompute_freqs_cis

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Per-pair frequencies: theta_i = theta^(-2i/dim) for i = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions 0 .. end-1
    # Angle m * theta_i for every (position, frequency) combination
    freqs = torch.outer(t, freqs).float()
    # Unit-magnitude complex numbers e^(i * m * theta_i)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

Purpose

Precomputes the rotation angles for all positions and all dimension pairs.

Example: dim=4, end=3, theta=10000

Step 1: Compute base frequencies

torch.arange(0, 4, 2)  # [0, 2]
[: (4 // 2)]           # [0, 2] (first 2 elements)
.float() / 4           # [0.0, 0.5]
theta ** (...)         # [10000^0.0, 10000^0.5] = [1.0, 100.0]
freqs = 1.0 / (...)    # [1.0, 0.01]

Result: freqs = [1.0, 0.01] (2 frequency values for 2 dimension pairs)

Step 2: Create position indices

t = torch.arange(3)  # [0, 1, 2]  (positions 0, 1, 2)

Step 3: Compute angles for all position-frequency combinations

freqs = torch.outer(t, freqs)
# Outer product: each position × each frequency

     [1.0,  0.01]
[0]  [0.0,  0.0 ]
[1]  [1.0,  0.01]
[2]  [2.0,  0.02]

Meaning:

  • Position 0: rotations [0.0, 0.0]
  • Position 1: rotations [1.0, 0.01]
  • Position 2: rotations [2.0, 0.02]

Step 4: Convert to complex numbers (rotation representation)

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
# polar(magnitude, angle) = magnitude * e^(i*angle)
# magnitude = 1.0 for all

Result: Complex numbers representing rotations

Position 0: [e^(i*0.0),  e^(i*0.0) ] = [1.0+0i,     1.0+0i    ]
Position 1: [e^(i*1.0),  e^(i*0.01)] = [0.54+0.84i, 0.99+0.01i]
Position 2: [e^(i*2.0),  e^(i*0.02)] = [-0.42+0.91i, 0.99+0.02i]

Shape: [3, 2] → (3 positions, 2 dimension pairs)
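
To verify this walkthrough end to end, a minimal runnable check (assuming the precompute_freqs_cis definition above) is:

freqs_cis = precompute_freqs_cis(dim=4, end=3)
print(freqs_cis.shape)  # torch.Size([3, 2])
print(freqs_cis)
# tensor([[ 1.0000+0.0000j,  1.0000+0.0000j],
#         [ 0.5403+0.8415j,  1.0000+0.0100j],
#         [-0.4161+0.9093j,  0.9998+0.0200j]])

These are the same rotations as above; the hand trace just rounds 0.99995 down to 0.99.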


Function 2: reshape_for_broadcast

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim  # x must have more than one dimension
    # freqs_cis must match x's (seq_len, head_dim/2) dimensions
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep dims 1 (seq_len) and ndim-1 (pairs); set all others to 1
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

Purpose

Reshape freqs_cis to broadcast correctly with the input tensor.

Example

Suppose:

  • x has shape [batch, seq_len, num_heads, head_dim] = [2, 3, 4, 4]
  • After converting to complex: x_complex has shape [2, 3, 4, 2] (last dim halved)
  • freqs_cis has shape [3, 2] (seq_len=3, head_dim/2=2)

Step-by-Step:

ndim = 4  # x_complex has 4 dimensions

# Check: freqs_cis.shape should be (x.shape[1], x.shape[-1])
# freqs_cis.shape = (3, 2) ✓
# x.shape[1] = 3 (seq_len), x.shape[-1] = 2 (head_dim/2) ✓

# Build new shape
shape = []
for i, d in enumerate([2, 3, 4, 2]):
    if i == 1 or i == 3:  # Keep dimensions 1 and 3
        shape.append(d)
    else:
        shape.append(1)  # Add singleton dimensions

shape = [1, 3, 1, 2]

Result: freqs_cis.view(1, 3, 1, 2)

Why? This allows broadcasting:

x_complex:      [2, 3, 4, 2]
freqs_cis:      [1, 3, 1, 2]
                 ↓  ↓  ↓  ↓
After broadcast:[2, 3, 4, 2]  (multiplication works!)
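
A quick way to see this broadcasting in action, assuming the definitions above and the same toy shapes:

x_complex = torch.zeros(2, 3, 4, 2, dtype=torch.complex64)  # stand-in for the complex-valued x
freqs_cis = precompute_freqs_cis(dim=4, end=3)              # shape [3, 2]

reshaped = reshape_for_broadcast(freqs_cis, x_complex)
print(reshaped.shape)                # torch.Size([1, 3, 1, 2])
print((x_complex * reshaped).shape)  # torch.Size([2, 3, 4, 2]) -- broadcasting works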

Function 3: apply_rotary_emb

from typing import Tuple

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # View consecutive (real, imag) pairs of the last dim as complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication applies the rotation; convert back to real pairs
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

Purpose

Apply rotary embeddings to query and key tensors.

Complete Example

Let's trace through with concrete numbers!

Input:

xq.shape = [1, 2, 1, 4]  # [batch, seq_len, heads, head_dim]
xk.shape = [1, 2, 1, 4]
freqs_cis.shape = [2, 2]  # [seq_len, head_dim/2]

# Suppose xq looks like:
xq = [
  [  # batch 0
    [[1.0, 2.0, 3.0, 4.0]],  # position 0, head 0
    [[5.0, 6.0, 7.0, 8.0]]   # position 1, head 0
  ]
]

Step 1: Reshape to pair dimensions

xq.float().reshape(*xq.shape[:-1], -1, 2)
# [1, 2, 1, 4] → [1, 2, 1, 2, 2]
#                 batch, seq, heads, pairs, (real,imag)

[
  [
    [[[1.0, 2.0], [3.0, 4.0]]],  # pos 0: pairs (1,2) and (3,4)
    [[[5.0, 6.0], [7.0, 8.0]]]   # pos 1: pairs (5,6) and (7,8)
  ]
]

Step 2: Convert to complex numbers

xq_ = torch.view_as_complex(...)
# Shape: [1, 2, 1, 2]  (complex numbers)

[
  [
    [[1.0+2.0i, 3.0+4.0i]],  # pos 0
    [[5.0+6.0i, 7.0+8.0i]]   # pos 1
  ]
]

Step 3: Reshape freqs_cis for broadcasting

freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
# [2, 2] → [1, 2, 1, 2]

# Suppose freqs_cis contains:
[
  [
    [[1.0+0.0i, 1.0+0.0i]],     # pos 0: no rotation
    [[0.54+0.84i, 0.99+0.01i]]  # pos 1: rotate!
  ]
]

Step 4: Apply rotation (complex multiplication)

xq_out = xq_ * freqs_cis

Position 0:
  (1.0+2.0i) * (1.0+0.0i) = 1.0+2.0i
  (3.0+4.0i) * (1.0+0.0i) = 3.0+4.0i

Position 1:
  (5.0+6.0i) * (0.54+0.84i) = 5.0*0.54 - 6.0*0.84 + i(5.0*0.84 + 6.0*0.54)
                             = 2.7 - 5.04 + i(4.2 + 3.24)
                             = -2.34 + 7.44i
  
  (7.0+8.0i) * (0.99+0.01i) = 6.85 + 8.0i (approximately)
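
You can confirm this arithmetic with plain Python complex numbers, using the same rounded rotation values as the trace above:

print((5 + 6j) * (0.54 + 0.84j))  # ≈ (-2.34+7.44j)
print((7 + 8j) * (0.99 + 0.01j))  # ≈ (6.85+7.99j)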

Step 5: Convert back to real numbers

torch.view_as_real(xq_out)
# [1, 2, 1, 2] complex → [1, 2, 1, 2, 2] real

[
  [
    [[[1.0, 2.0], [3.0, 4.0]]],      # pos 0: unchanged
    [[[-2.34, 7.44], [6.85, 8.0]]]   # pos 1: rotated!
  ]
]

Step 6: Flatten back to original shape

.flatten(3)  # Flatten last 2 dims
# [1, 2, 1, 2, 2] → [1, 2, 1, 4]

[
  [
    [[1.0, 2.0, 3.0, 4.0]],      # pos 0
    [[-2.34, 7.44, 6.85, 8.0]]   # pos 1: rotated!
  ]
]
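
Putting all three functions together, here is a minimal end-to-end run of this example, assuming the definitions above. Here freqs_cis comes from precompute_freqs_cis(dim=4, end=2), so it holds the exact rotations rather than the rounded 0.54/0.84 and 0.99/0.01 values, which is why the printed numbers differ slightly from the hand trace:

xq = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]],
                    [[5.0, 6.0, 7.0, 8.0]]]])  # shape [1, 2, 1, 4]
xk = xq.clone()

freqs_cis = precompute_freqs_cis(dim=4, end=2)  # shape [2, 2]
xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cis)

print(xq_out)
# tensor([[[[ 1.0000,  2.0000,  3.0000,  4.0000]],
#          [[-2.3473,  7.4492,  6.9197,  8.0696]]]])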

Summary

Function                 Purpose                    Input → Output
precompute_freqs_cis     Compute rotation angles    dim, seq_len → complex rotations [seq_len, dim/2]
reshape_for_broadcast    Prepare for broadcasting   [seq_len, dim/2] → [1, seq_len, 1, dim/2]
apply_rotary_emb         Rotate embeddings          queries, keys → rotated queries, keys

Each position gets rotated by different angles, encoding its position. The rotation is done by treating consecutive pairs of dimensions as 2D vectors and rotating them in that 2D plane. This creates a positional encoding that naturally captures relative positions through the geometry of rotations.
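
Finally, the relative-position property itself can be checked numerically: rotating the same query/key pair at different absolute positions gives identical dot products whenever the offset is the same. A minimal sketch, assuming the three definitions above (the dot_at helper is hypothetical, introduced here only for this check):

torch.manual_seed(0)
q = torch.randn(1, 1, 1, 4)
k = torch.randn(1, 1, 1, 4)
freqs_cis = precompute_freqs_cis(dim=4, end=10)

def dot_at(m, n):
    # Rotate q to position m and k to position n, then take their dot product.
    qm, _ = apply_rotary_emb(q, q, freqs_cis[m:m+1])
    kn, _ = apply_rotary_emb(k, k, freqs_cis[n:n+1])
    return (qm * kn).sum().item()

print(dot_at(2, 5), dot_at(4, 7))  # equal (up to float precision): both offsets are -3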
