qwen2.5-vl pseudocode (window attention)
from typing import List, Tuple, Union

import numpy as np

# Define a Tensor type for clarity in this pseudocode
Tensor = np.ndarray  # In a real implementation this would be a framework-specific tensor type

# NOTE: helpers such as combine_patches, linear_projection, reshape,
# transpose, matmul, softmax, sqrt, pad, Linear, Sequential and SwiGLU are
# pseudocode stand-ins for standard framework primitives.
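
# --- Hedged helper sketch (not in the original gist) ---
# resize_image is called below but never defined. A minimal
# nearest-neighbour stand-in, assuming HWC numpy images; the real
# preprocessor presumably uses a proper interpolating resize.
def resize_image(image: Tensor, size: Tuple[int, int]) -> Tensor:
    new_h, new_w = size
    h, w = image.shape[0], image.shape[1]
    rows = np.minimum(np.arange(new_h) * h // new_h, h - 1)
    cols = np.minimum(np.arange(new_w) * w // new_w, w - 1)
    return image[rows][:, cols]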

# Qwen2.5-VL Vision Encoder Pseudocode
class Qwen25VisionEncoder:
    def __init__(self,
                 hidden_size: int = 1280,
                 num_layers: int = 32,
                 num_heads: int = 16,
                 intermediate_size: int = 3456,
                 patch_size: int = 14,
                 window_size: int = 112,
                 full_attention_block_indexes: Tuple[int, ...] = (7, 15, 23, 31)):
        self.hidden_size: int = hidden_size
        self.num_layers: int = num_layers
        self.num_heads: int = num_heads
        self.patch_size: int = patch_size
        self.window_size: int = window_size  # 112 px = 8 × 8 patches of 14 px
        self.full_attention_block_indexes: List[int] = list(full_attention_block_indexes)

        # Initialize transformer layers: window attention everywhere except
        # the four listed blocks, which use full attention
        self.layers: List[Union[WindowAttentionBlock, FullAttentionBlock]] = []
        for i in range(num_layers):
            if i in self.full_attention_block_indexes:
                # Full attention layer
                self.layers.append(FullAttentionBlock(hidden_size, num_heads, intermediate_size))
            else:
                # Window attention layer
                self.layers.append(WindowAttentionBlock(hidden_size, num_heads, intermediate_size, window_size))

        # Initialize rotary embeddings for temporal and spatial positions (MRoPE)
        self.rotary_emb: MultimodalRotaryEmbedding = MultimodalRotaryEmbedding(hidden_size // num_heads)

        # MLP for feature compression (vision-language merger);
        # out_channels=8192 matches the 72B model's LLM hidden size
        self.feature_compressor: VisionLanguageMerger = VisionLanguageMerger(in_channels=hidden_size, out_channels=8192)

    def preprocess_image(self, image: Tensor) -> Tensor:
        """Resize image height and width up to the nearest multiples of 28."""
        # 28 = 2 × patch_size, so the patch grid always supports the
        # 2 × 2 feature grouping performed by the merger
        h, w = image.shape[0], image.shape[1]
        h = (h + 27) // 28 * 28  # Round up to the nearest multiple of 28
        w = (w + 27) // 28 * 28  # e.g. 480 × 640 becomes 504 × 644
        return resize_image(image, (h, w))

    def extract_patches(self, image: Union[Tensor, List[Tensor]], is_video: bool = False) -> Tuple[List[Tensor], Tuple[List[int], List[int], List[int]]]:
        """Extract 14 × 14 patches from an image or video."""
        patches: List[Tensor] = []
        t_pos_ids: List[int] = []
        h_pos_ids: List[int] = []
        w_pos_ids: List[int] = []
        if not is_video:
            # For a static image
            height, width = image.shape[0], image.shape[1]
            for i in range(0, height, self.patch_size):
                for j in range(0, width, self.patch_size):
                    if i + self.patch_size <= height and j + self.patch_size <= width:
                        patch = image[i:i+self.patch_size, j:j+self.patch_size]
                        patches.append(patch)
                        # One (t, h, w) position triple per patch
                        t_pos_ids.append(0)  # All patches share temporal position 0 for an image
                        h_pos_ids.append(i // self.patch_size)
                        w_pos_ids.append(j // self.patch_size)
        else:
            # For video, group two consecutive frames
            for t in range(0, len(image), 2):  # Step by 2 frames
                frame1 = image[t]
                frame2 = image[min(t + 1, len(image) - 1)]  # Handle a potential odd number of frames
                height, width = frame1.shape[0], frame1.shape[1]
                for i in range(0, height, self.patch_size):
                    for j in range(0, width, self.patch_size):
                        if i + self.patch_size <= height and j + self.patch_size <= width:
                            # Extract patches from both frames and combine
                            patch1 = frame1[i:i+self.patch_size, j:j+self.patch_size]
                            patch2 = frame2[i:i+self.patch_size, j:j+self.patch_size]
                            combined_patch = combine_patches(patch1, patch2)
                            patches.append(combined_patch)
                            t_pos_ids.append(t)  # Timestamp (aligned to absolute time in the video)
                            h_pos_ids.append(i // self.patch_size)
                            w_pos_ids.append(j // self.patch_size)
        return patches, (t_pos_ids, h_pos_ids, w_pos_ids)
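
    # Worked example (illustrative): a 224 × 224 image has 224 / 14 = 16
    # patches per side, i.e. 256 patches total; with window_size = 112 px an
    # attention window covers 8 × 8 = 64 patches, so this image splits into
    # 2 × 2 = 4 windows.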

    def forward(self, visual_input: Union[Tensor, List[Tensor]], is_video: bool = False) -> Tensor:
        # Preprocess and extract patches
        if not is_video:
            preprocessed: Tensor = self.preprocess_image(visual_input)
        else:
            preprocessed: List[Tensor] = [self.preprocess_image(frame) for frame in visual_input]
        patches, (t_pos_ids, h_pos_ids, w_pos_ids) = self.extract_patches(preprocessed, is_video)

        # Initial embedding of patches
        x: Tensor = linear_projection(patches, self.hidden_size)

        # Multimodal Rotary Position Embedding (MRoPE) depends only on the
        # patch positions, so compute it once rather than once per layer
        rotary_emb_data: Tuple[Tensor, Tensor, Tensor] = self.rotary_emb(t_pos_ids, h_pos_ids, w_pos_ids)

        # Process through the transformer layers (window or full attention)
        for layer in self.layers:
            x = layer(x, rotary_emb_data)

        # Compress features for LLM integration: group 4 adjacent patch
        # features and project them through the vision-language merger
        grouped_features: Tensor = group_adjacent_features(x, group_size=4)
        compressed_features: Tensor = self.feature_compressor(grouped_features)
        return compressed_features
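
# --- Hedged helper sketch (not in the original gist) ---
# group_adjacent_features is used above but never defined. A minimal
# sketch that groups runs of `group_size` consecutive patch features,
# assuming x is [num_patches, hidden]; the real merger groups spatially
# adjacent 2 × 2 patches.
def group_adjacent_features(x: Tensor, group_size: int = 4) -> Tensor:
    num_patches, channels = x.shape
    usable = (num_patches // group_size) * group_size
    return x[:usable].reshape(-1, group_size, channels)  # [num_groups, group_size, C]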

class WindowAttentionBlock:
    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int, window_size: int):
        self.attention: MultiHeadWindowAttention = MultiHeadWindowAttention(hidden_size, num_heads, window_size)
        self.ffn: FFN = FFN(hidden_size, intermediate_size, activation="SwiGLU")
        self.norm1: RMSNorm = RMSNorm(hidden_size)
        self.norm2: RMSNorm = RMSNorm(hidden_size)

    def forward(self, x: Tensor, rotary_emb_data: Tuple[Tensor, Tensor, Tensor]) -> Tensor:
        # Pre-norm attention with residual connection
        residual: Tensor = x
        x = self.norm1(x)
        x = self.attention(x, rotary_emb_data)
        x = x + residual

        # Pre-norm FFN with residual connection
        residual = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = x + residual
        return x

    # Let a block be called like a function, as the encoder's layer loop does
    __call__ = forward
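
# --- Hedged helper sketches (not in the original gist) ---
# RMSNorm and the SwiGLU FFN are referenced but never defined. Minimal
# numpy stand-ins following the standard formulations; weight shapes and
# initialization are illustrative assumptions.
class RMSNorm:
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        self.weight = np.ones(hidden_size)
        self.eps = eps

    def __call__(self, x: Tensor) -> Tensor:
        # Scale by the inverse root-mean-square of each feature vector
        rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + self.eps)
        return x / rms * self.weight


class FFN:
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int, activation: str = "SwiGLU"):
        self.w_gate = np.random.randn(hidden_size, intermediate_size) * 0.02
        self.w_up = np.random.randn(hidden_size, intermediate_size) * 0.02
        self.w_down = np.random.randn(intermediate_size, hidden_size) * 0.02

    def __call__(self, x: Tensor) -> Tensor:
        gate = x @ self.w_gate
        gate = gate / (1.0 + np.exp(-gate))  # SiLU activation
        return (gate * (x @ self.w_up)) @ self.w_down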

class MultiHeadWindowAttention:
    def __init__(self, hidden_size: int, num_heads: int, window_size: int):
        self.hidden_size: int = hidden_size
        self.num_heads: int = num_heads
        self.window_size: int = window_size  # Window size in number of tokens
        self.head_dim: int = hidden_size // num_heads
        self.q_proj: Linear = Linear(hidden_size, hidden_size)
        self.k_proj: Linear = Linear(hidden_size, hidden_size)
        self.v_proj: Linear = Linear(hidden_size, hidden_size)
        self.out_proj: Linear = Linear(hidden_size, hidden_size)

    def forward(self, x: Tensor, rotary_emb_data: Tuple[Tensor, Tensor, Tensor]) -> Tensor:
        batch_size, seq_len, _ = x.shape

        # Project queries, keys, and values
        q: Tensor = self.q_proj(x)  # [B, L, C]
        k: Tensor = self.k_proj(x)  # [B, L, C]
        v: Tensor = self.v_proj(x)  # [B, L, C]

        # Reshape for multi-head attention
        q = reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))  # [B, L, NH, C/NH]
        k = reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))  # [B, L, NH, C/NH]
        v = reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim))  # [B, L, NH, C/NH]

        # Apply rotary position embeddings (MRoPE); these encode the
        # temporal/spatial relationships between patches
        q = apply_rotary_emb(q, rotary_emb_data)
        k = apply_rotary_emb(k, rotary_emb_data)

        # Calculate the number of windows in the sequence.
        # Note: this pseudocode windows consecutive tokens of the flattened
        # sequence; the real implementation orders patches so that each
        # 8 × 8 spatial window is contiguous in the sequence.
        num_windows: int = (seq_len + self.window_size - 1) // self.window_size

        # Pad the sequence axis (torch-style pad spec) so its length is
        # divisible by window_size
        pad_len: int = (num_windows * self.window_size) - seq_len
        if pad_len > 0:
            q = pad(q, (0, 0, 0, 0, 0, pad_len))
            k = pad(k, (0, 0, 0, 0, 0, pad_len))
            v = pad(v, (0, 0, 0, 0, 0, pad_len))
        padded_seq_len: int = seq_len + pad_len

        # Reshape to [B, num_windows, window_size, NH, C/NH]
        q = reshape(q, (batch_size, num_windows, self.window_size, self.num_heads, self.head_dim))
        k = reshape(k, (batch_size, num_windows, self.window_size, self.num_heads, self.head_dim))
        v = reshape(v, (batch_size, num_windows, self.window_size, self.num_heads, self.head_dim))

        # Transpose to [B, num_windows, NH, window_size, C/NH]
        q = transpose(q, (0, 1, 3, 2, 4))
        k = transpose(k, (0, 1, 3, 2, 4))
        v = transpose(v, (0, 1, 3, 2, 4))

        # Scaled dot-product attention within each window:
        # [B, num_windows, NH, window_size, C/NH] × [B, num_windows, NH, C/NH, window_size]
        scores: Tensor = matmul(q, transpose(k, (0, 1, 2, 4, 3))) / sqrt(self.head_dim)
        # Result: [B, num_windows, NH, window_size, window_size]

        # Apply softmax to get attention weights
        attn_weights: Tensor = softmax(scores, dim=-1)  # [B, num_windows, NH, window_size, window_size]

        # Apply attention weights to values
        output: Tensor = matmul(attn_weights, v)  # [B, num_windows, NH, window_size, C/NH]

        # Transpose back to [B, num_windows, window_size, NH, C/NH]
        output = transpose(output, (0, 1, 3, 2, 4))

        # Reshape back to [B, padded_seq_len, NH, C/NH]
        output = reshape(output, (batch_size, padded_seq_len, self.num_heads, self.head_dim))

        # Remove padding if it was added
        if pad_len > 0:
            output = output[:, :seq_len, :, :]

        # Reshape to [B, seq_len, C]
        output = reshape(output, (batch_size, seq_len, self.hidden_size))

        # Output projection
        output = self.out_proj(output)
        return output

    __call__ = forward
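
# --- Hedged sketch (not in the original gist) ---
# FullAttentionBlock is instantiated by the encoder but never defined.
# Per the Qwen2.5-VL design it is the same pre-norm block with attention
# over the whole sequence; one way to express that here is a single
# window spanning the full sequence.
class FullAttentionBlock(WindowAttentionBlock):
    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        super().__init__(hidden_size, num_heads, intermediate_size, window_size=-1)

    def forward(self, x: Tensor, rotary_emb_data: Tuple[Tensor, Tensor, Tensor]) -> Tensor:
        # One window covering the entire sequence => full attention
        self.attention.window_size = x.shape[1]
        return WindowAttentionBlock.forward(self, x, rotary_emb_data)

    # Re-alias so calls dispatch to this subclass's forward
    __call__ = forward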

class MultimodalRotaryEmbedding:
    def __init__(self, dim: int):
        self.dim: int = dim
        self.freqs: Tensor = precompute_freqs(dim)

    def __call__(self, t_ids: List[int], h_ids: List[int], w_ids: List[int]) -> Tuple[Tensor, Tensor, Tensor]:
        """Generate rotary embeddings for the temporal, height, and width axes."""
        t_emb: Tensor = compute_rotary_embedding(t_ids, self.freqs)
        h_emb: Tensor = compute_rotary_embedding(h_ids, self.freqs)
        w_emb: Tensor = compute_rotary_embedding(w_ids, self.freqs)
        return (t_emb, h_emb, w_emb)
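
# --- Hedged helper sketches (not in the original gist) ---
# precompute_freqs and compute_rotary_embedding are referenced but never
# defined. Minimal stand-ins following the standard RoPE formulation
# (theta = 10000). rotate_pairs illustrates the rotation mechanics for a
# single positional component; how apply_rotary_emb splits head channels
# between the t/h/w components is a modelling detail not reproduced here.
def precompute_freqs(dim: int, theta: float = 10000.0) -> Tensor:
    # One inverse frequency per rotated channel pair
    return 1.0 / (theta ** (np.arange(0, dim, 2) / dim))


def compute_rotary_embedding(pos_ids: List[int], freqs: Tensor) -> Tensor:
    angles = np.outer(np.asarray(pos_ids, dtype=np.float64), freqs)  # [L, dim/2]
    # Stack cos/sin so downstream code can rotate channel pairs
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)  # [L, dim/2, 2]


def rotate_pairs(x: Tensor, emb: Tensor) -> Tensor:
    # x: [B, L, NH, D]; emb: [L, D/2, 2] holding (cos, sin) per pair
    cos = emb[..., 0][None, :, None, :]  # broadcast over batch and heads
    sin = emb[..., 1][None, :, None, :]
    x0, x1 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin
    out[..., 1::2] = x0 * sin + x1 * cos
    return out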

class VisionLanguageMerger:
    def __init__(self, in_channels: int, out_channels: int):
        self.mlp: Sequential = Sequential(
            Linear(in_channels * 4, in_channels * 2),  # First compress the 4 grouped features
            SwiGLU(),
            Linear(in_channels * 2, out_channels)  # Project to the LLM embedding dimension
        )

    def forward(self, grouped_features: Tensor) -> Tensor:
        # Flatten each group of 4 patch features: [N, 4, C] -> [N, 4 * C]
        num_groups, group_size, channels = grouped_features.shape
        flat_features: Tensor = reshape(grouped_features, (num_groups, group_size * channels))
        # Project to the LLM dimension
        compressed: Tensor = self.mlp(flat_features)
        return compressed

    __call__ = forward
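
# --- Hypothetical usage sketch (not in the original gist) ---
# Illustrates the intended call pattern; it will not run end-to-end
# because several helpers above remain pseudocode.
# encoder = Qwen25VisionEncoder()
# image = np.random.rand(480, 640, 3)      # HWC image
# visual_tokens = encoder.forward(image)   # [num_groups, 8192] for the 72B merger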