# Gist by @cat-state, created April 23, 2025
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, BlockMask
import dataclasses
# torch.use_deterministic_algorithms(True)
torch._dynamo.config.cache_size_limit = 128
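# Minimal single-file reimplementation of a Qwen2-style decoder-only transformer
# (RMSNorm, rotary embeddings, GQA attention via SDPA, SwiGLU MLP) with a
# preallocated KV cache, weight import from a HuggingFace checkpoint, and a
# benchmark/generation harness against `transformers` at the bottom.
# The Config defaults below appear to correspond to the Qwen2.5-1.5B-class models.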
@dataclasses.dataclass
class Config:
    bos_token_id: int = 151643
    eos_token_id: int = 151643
    hidden_act: str = "silu"
    hidden_size: int = 1536
    initializer_range: float = 0.02
    intermediate_size: int = 8960
    max_position_embeddings: int = 131072
    max_window_layers: int = 21
    model_type: str = "qwen2"
    num_attention_heads: int = 12
    num_hidden_layers: int = 28
    num_key_value_heads: int = 2
    rms_norm_eps: float = 1e-6
    rope_theta: int = 10000
    sliding_window: int = 4096
    tie_word_embeddings: bool = False
    dtype: torch.dtype = torch.bfloat16
    use_cache: bool = True
    use_mrope: bool = False
    use_sliding_window: bool = False
    vocab_size: int = 151936
# @torch.compile(fullgraph=True)
def rmsnorm(x, w, eps=1e-6):
    ori_dtype = x.dtype
    x = x.to(torch.float32)
    variance = (x * x).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return w.expand_as(x) * x.to(ori_dtype)
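# rmsnorm shape contract (illustrative only): x is [..., hidden_size], w is [hidden_size];
# the mean-square is computed in fp32, then x is cast back to its original dtype before
# the elementwise scale by w, e.g.
#   rmsnorm(torch.randn(2, 4, 1536, dtype=torch.bfloat16), torch.ones(1536, dtype=torch.bfloat16)).shape == (2, 4, 1536)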
# @torch.compile(fullgraph=True)
def rotary(x, pos_idx, rope_theta):
    # x: [batch, nheads, seqlen, headdim] (as called from FFFormerLayer) or [batch, seqlen, headdim]
    # pos_idx: [seqlen]
    ori_dtype = x.dtype
    x = x.to(torch.float32)
    shape = x.shape
    x = x.view(-1, shape[-2], shape[-1]) if len(shape) == 4 else x
    dim = x.shape[-1]
    half_dim = dim // 2
    # NeoX/Llama-style (rotate-half) RoPE as used by Qwen2: dim//2 frequencies, each
    # applied to the matching element of the first and second halves of the head dim.
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, half_dim, device=x.device, dtype=x.dtype) / half_dim))
    freqs = pos_idx.view(-1, 1) * inv_freq  # [seqlen, dim//2]
    cos = torch.cos(freqs)  # [seqlen, dim//2]
    sin = torch.sin(freqs)  # [seqlen, dim//2]
    # Apply rotary embedding
    x1, x2 = x[..., :half_dim], x[..., half_dim:]
    # Single concatenation operation
    out = torch.cat([
        x1 * cos.unsqueeze(0) - x2 * sin.unsqueeze(0),
        x2 * cos.unsqueeze(0) + x1 * sin.unsqueeze(0)
    ], dim=-1)
    return out.view(shape).to(ori_dtype)
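# Hypothetical shape check for rotary (not part of the model): with q laid out as
# [batch, heads, seqlen, head_dim], the positions broadcast over batch*heads, e.g.
#   q = torch.randn(2, 12, 5, 128); rotary(q, torch.arange(5), 10000.0).shape == q.shape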
PRONT = print
class FFFormerLayer(nn.Module):
    up_gate: nn.Linear
    down: nn.Linear
    o: nn.Linear
    ilnw: nn.Parameter
    alnw: nn.Parameter
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.up_gate = nn.Linear(cfg.hidden_size, cfg.intermediate_size * 2, bias=False, dtype=cfg.dtype)
        self.down = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False, dtype=cfg.dtype)
        # QKV dimensions (GQA: num_kv_heads <= num_heads)
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.num_heads = cfg.num_attention_heads
        self.num_kv_heads = cfg.num_key_value_heads
        # Combined QKV projection with merged biases
        total_qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_kv_heads)
        self.qkv = nn.Linear(cfg.hidden_size, total_qkv_dim, bias=True, dtype=cfg.dtype)
        # Split points within the fused QKV output
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        self.k_start = self.q_size
        self.v_start = self.k_start + self.kv_size
        self.o = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False, dtype=cfg.dtype)
        self.ilnw = nn.Parameter(torch.ones([cfg.hidden_size], dtype=torch.float32))
        self.alnw = nn.Parameter(torch.ones([cfg.hidden_size], dtype=torch.float32))
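    # forward() contract: x is [N, T, hidden]; kcache/vcache are either None (pure
    # causal prefill with no cache) or preallocated per-layer caches of shape
    # [N, num_kv_heads, cache_len, head_dim]; write_idxs is [N, T] (broadcastable)
    # giving the cache slots this call's K/V should be written to, and
    # position_idxs is [T] with the absolute positions used for RoPE.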
    def forward(self, x, kcache, vcache, write_idxs, position_idxs, attn_mask=None, causal=True):
        N, T, C = x.shape
        # Apply layer norm and QKV projection
        x_ln = rmsnorm(x, self.ilnw, self.cfg.rms_norm_eps)
        # PRONT(f"my input{x_ln}")
        # PRONT(f"my qkv {self.qkv.weight}")
        qkv = self.qkv(x_ln.to(self.cfg.dtype))  # Bias is automatically added by nn.Linear
        # Split QKV
        q = qkv[..., :self.q_size]
        k = qkv[..., self.k_start:self.v_start]
        v = qkv[..., self.v_start:]
        # Reshape and transpose
        # q = q.reshape(N, T * (self.num_heads // self.num_kv_heads), self.num_kv_heads, -1).transpose(1, 2).contiguous()  # [N, num_heads, T, head_dim]
        q = q.reshape(N, T, self.num_heads, -1).transpose(1, 2).contiguous()  # [N, num_heads, T, head_dim]
        k = k.reshape(N, T, self.num_kv_heads, -1).transpose(1, 2).contiguous()  # [N, num_kv_heads, T, head_dim]
        v = v.reshape(N, T, self.num_kv_heads, -1).transpose(1, 2).contiguous()  # [N, num_kv_heads, T, head_dim]
        # PRONT(f"my q{q[0, 0, 0, :10]}")
        # q = rotary(q, position_idxs.repeat_interleave((self.num_heads // self.num_kv_heads), 0), self.cfg.rope_theta)
        # print(position_idxs)
        q = rotary(q, position_idxs, self.cfg.rope_theta)
        k = rotary(k, position_idxs, self.cfg.rope_theta)
        # PRONT(f"my qrot{q}")
        if kcache is None:
            kcache, vcache = k, v
        else:
            # Scatter the new K/V into the preallocated cache at write_idxs
            batch_idx = torch.arange(N, device=kcache.device)[:, None, None]  # (N, 1, 1)
            head_idx = torch.arange(self.num_kv_heads, device=kcache.device)[None, :, None]  # (1, H, 1)
            write_idxs = write_idxs[:, None, :]  # (N, 1, S)
            kcache[batch_idx, head_idx, write_idxs, :] = k
            vcache[batch_idx, head_idx, write_idxs, :] = v
            # kcache.scatter_(2, write_idxs, k)
            # vcache.scatter_(2, write_idxs, v)
        # attn = flex_attention(q, kcache, vcache, block_mask=attn_mask)
        if T > 1 or causal:
            attn = F.scaled_dot_product_attention(q, kcache, vcache, is_causal=causal, enable_gqa=True)
            attn = attn.transpose(1, 2).reshape(N, T, C)
        else:
            # Single-token decode: fold the grouped query heads into the query-length dim
            # so SDPA sees [N, num_kv_heads, heads_per_kv, head_dim] against the full cache.
            q = q.reshape(N, self.num_kv_heads, self.num_heads // self.num_kv_heads, -1)
            attn = F.scaled_dot_product_attention(q, kcache, vcache, is_causal=False)
            attn = attn.reshape(N, self.num_heads, T, -1).transpose(1, 2).reshape(N, T, C)
        # PRONT(f"my attn{attn}")
        x = x + self.o(attn)
        up_gate = self.up_gate(rmsnorm(x, self.alnw, self.cfg.rms_norm_eps).to(self.cfg.dtype))
        mid = F.silu(up_gate[..., :self.cfg.intermediate_size]) * up_gate[..., self.cfg.intermediate_size:]
        x = x + self.down(mid)
        # PRONT(f"my output{x}")
        # globals()["PRONT"] = lambda x: x
        return x
class FFFormer(nn.Module):
    embed: nn.Embedding
    unembed: nn.Linear
    oln: nn.Parameter
    layers: nn.ModuleList  # [FFFormerLayer]
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size, dtype=cfg.dtype)
        self.unembed = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False, dtype=cfg.dtype)
        self.oln = nn.Parameter(torch.ones([cfg.hidden_size], dtype=torch.float32))
        self.layers = nn.ModuleList([FFFormerLayer(cfg) for _ in range(cfg.num_hidden_layers)])
    def make_mask(self, input_ids):
        N, T = input_ids.shape
        qpkv = self.cfg.num_attention_heads // self.cfg.num_key_value_heads
        def causalish(b, h, q_idx, kv_idx):
            return (q_idx // qpkv) >= kv_idx
        mask = create_block_mask(causalish, N, None, T * qpkv, T)
        return mask
        # Unreachable dense-mask alternative, kept for reference:
        # q_heads_per_kv = self.cfg.num_attention_heads // self.cfg.num_key_value_heads
        # attn_mask = torch.ones(N, T, T, dtype=torch.bool, device="cuda").tril(diagonal=0).repeat_interleave(q_heads_per_kv, 1)
        # return attn_mask
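    # make_mask builds a BlockMask for the (currently commented-out) flex_attention
    # path, where the grouped query heads are flattened into the query-length axis;
    # the active SDPA path ignores it and relies on is_causal/enable_gqa instead.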
    @torch.compile(dynamic=False)
    def forward(self, input_ids, kcache, vcache, write_idxs, position_idxs, attn_mask, causal=True):
        x = self.embed(input_ids).to(self.cfg.dtype)
        for i, l in enumerate(self.layers):
            kcache_i, vcache_i = (None, None) if kcache is None else (kcache[i], vcache[i])
            x = l(x, kcache_i, vcache_i, write_idxs, position_idxs, attn_mask, causal=causal)
        x = rmsnorm(x, self.oln, self.cfg.rms_norm_eps).to(self.cfg.dtype)
        return self.unembed(x)
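    # Model-level cache layout, as used by generate() below: kcache/vcache are stacked
    # per-layer caches of shape [num_layers, N, num_kv_heads, cache_len, head_dim], and
    # forward() hands slice i to layer i. Logits come back as [N, T, vocab_size].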
    @staticmethod
    def from_llamaish_hf_model(hf_cfg, state_dict):
        cfg = Config(
            vocab_size=hf_cfg.vocab_size,
            hidden_size=hf_cfg.hidden_size,
            num_hidden_layers=hf_cfg.num_hidden_layers,
            num_attention_heads=hf_cfg.num_attention_heads,
            num_key_value_heads=hf_cfg.num_key_value_heads,
            rms_norm_eps=hf_cfg.rms_norm_eps,
            dtype=torch.bfloat16,
            intermediate_size=hf_cfg.intermediate_size,
            max_position_embeddings=hf_cfg.max_position_embeddings,
            max_window_layers=hf_cfg.max_window_layers,
            sliding_window=hf_cfg.sliding_window,
            tie_word_embeddings=hf_cfg.tie_word_embeddings,
            use_cache=hf_cfg.use_cache,
            use_mrope=hf_cfg.use_mrope,
            use_sliding_window=hf_cfg.use_sliding_window,
            initializer_range=hf_cfg.initializer_range,
            hidden_act=hf_cfg.hidden_act,
            bos_token_id=hf_cfg.bos_token_id,
            eos_token_id=hf_cfg.eos_token_id,
            rope_theta=hf_cfg.rope_theta,
        )
        with torch.device('meta'):
            model = FFFormer(cfg)
        new_state = {}
        # Track which source keys we've handled
        handled_keys = set()
        # Direct mappings with regex patterns
        key_maps = {
            r'model\.embed_tokens\.weight': ('embed.weight', cfg.dtype),
            r'model\.norm\.weight': ('oln', cfg.dtype),
            r'lm_head\.weight': ('unembed.weight', cfg.dtype),
            r'model\.layers\.(\d+)\.input_layernorm\.weight': lambda m: (f'layers.{m.group(1)}.ilnw', cfg.dtype),
            r'model\.layers\.(\d+)\.post_attention_layernorm\.weight': lambda m: (f'layers.{m.group(1)}.alnw', cfg.dtype),
            r'model\.layers\.(\d+)\.self_attn\.o_proj\.weight': lambda m: (f'layers.{m.group(1)}.o.weight', cfg.dtype),
            r'model\.layers\.(\d+)\.mlp\.down_proj\.weight': lambda m: (f'layers.{m.group(1)}.down.weight', cfg.dtype),
        }
        # First pass - direct mappings and collecting weights for merging
        qkv_weights = {}
        qkv_biases = {}
        upgate_weights = {}
        import re
        for k, v in state_dict.items():
            # Try direct mappings first
            mapped = False
            for pattern, target in key_maps.items():
                match = re.match(pattern, k)
                if match:
                    if callable(target):
                        new_key, dtype = target(match)
                    else:
                        new_key, dtype = target
                    new_state[new_key] = v.to(dtype).detach().clone()
                    handled_keys.add(k)
                    mapped = True
                    break
            if mapped:
                continue
            # Collect QKV weights and biases
            qkv_match = re.match(r'model\.layers\.(\d+)\.self_attn\.(q|k|v)_proj\.(weight|bias)', k)
            if qkv_match:
                layer, qkv_type, param_type = qkv_match.groups()
                if param_type == 'weight':
                    qkv_weights.setdefault(layer, {})[qkv_type] = v
                else:
                    qkv_biases.setdefault(layer, {})[qkv_type] = v
                handled_keys.add(k)
                continue
            # Collect up/gate weights
            upgate_match = re.match(r'model\.layers\.(\d+)\.mlp\.(up|gate)_proj\.weight', k)
            if upgate_match:
                layer, proj_type = upgate_match.groups()
                upgate_weights.setdefault(layer, {})[proj_type] = v
                handled_keys.add(k)
                continue
        # Check for unhandled source keys
        unhandled_keys = set(state_dict.keys()) - handled_keys
        if unhandled_keys:
            raise ValueError(f"Found unhandled keys in source state dict: {unhandled_keys}")
        # Second pass - merge collected weights
        for layer in range(cfg.num_hidden_layers):
            layer = str(layer)
            # Merge QKV weights and set biases
            if layer in qkv_weights and layer in qkv_biases:
                qkv_w_dict = qkv_weights[layer]
                qkv_b_dict = qkv_biases[layer]
                if not all(k in qkv_w_dict for k in ['q', 'k', 'v']) or not all(k in qkv_b_dict for k in ['q', 'k', 'v']):
                    raise ValueError(f"Missing QKV weights/biases for layer {layer}")
                # Merge weights and biases
                q_w, k_w, v_w = qkv_w_dict['q'], qkv_w_dict['k'], qkv_w_dict['v']
                q_b, k_b, v_b = qkv_b_dict['q'], qkv_b_dict['k'], qkv_b_dict['v']
                # Merge weights
                qkv_merged = torch.cat([q_w, k_w, v_w], dim=0)
                new_state[f'layers.{layer}.qkv.weight'] = qkv_merged
                # Merge biases
                qkv_bias_merged = torch.cat([q_b, k_b, v_b])
                new_state[f'layers.{layer}.qkv.bias'] = qkv_bias_merged
            else:
                raise ValueError(f"No QKV weights/biases found for layer {layer}")
            # Merge up/gate
            if layer in upgate_weights:
                upgate_dict = upgate_weights[layer]
                if not all(k in upgate_dict for k in ['gate', 'up']):
                    raise ValueError(f"Missing up/gate weights for layer {layer}")
                gate, up = upgate_dict['gate'], upgate_dict['up']
                upgate_merged = torch.cat([gate, up], dim=0)
                new_state[f'layers.{layer}.up_gate.weight'] = upgate_merged
            else:
                raise ValueError(f"No up/gate weights found for layer {layer}")
        # Verify all required keys are present in new state dict
        expected_keys = set(model.state_dict().keys())
        missing_keys = expected_keys - set(new_state.keys())
        if missing_keys:
            raise ValueError(f"Missing required keys in converted state dict: {missing_keys}")
        # Load the state dict (assign=True replaces the meta-device parameters)
        new_state = {k: v.to(cfg.dtype) for k, v in new_state.items()}
        model.load_state_dict(new_state, strict=True, assign=True)
        for k, p in model.named_parameters():
            # make sure every parameter was materialized off the meta device
            assert not p.is_meta, k
        return model
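# Example of building the model from a HuggingFace checkpoint (mirrors the __main__
# block below; any Qwen2-family checkpoint with the same layer naming should work):
#   hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", torch_dtype="auto")
#   our_model = FFFormer.from_llamaish_hf_model(hf_model.config, hf_model.state_dict()).cuda()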
import tqdm
@torch.inference_mode()
def generate(model: FFFormer, input_ids: torch.Tensor, new_tokens: int):
    prefill_length = input_ids.shape[1]
    head_dim = model.cfg.hidden_size // model.cfg.num_attention_heads
    cache_shape = (model.cfg.num_hidden_layers, input_ids.shape[0], model.cfg.num_key_value_heads, prefill_length + new_tokens, head_dim)
    kcache = torch.full(cache_shape, fill_value=float("-inf"), dtype=model.cfg.dtype).cuda()
    vcache = torch.full(cache_shape, fill_value=float("-inf"), dtype=model.cfg.dtype).cuda()
    prefill_write_idxs = torch.arange(prefill_length)[None, :].cuda()
    prefill_position_idxs = torch.arange(prefill_length).cuda()
    mask = model.make_mask(input_ids)
    prefill_logits = model(input_ids, kcache[:, :, :, :prefill_length], vcache[:, :, :, :prefill_length], prefill_write_idxs, prefill_position_idxs, attn_mask=None, causal=True)
    last_output_token = torch.argmax(prefill_logits[:, -1:], dim=-1)
    all_output_tokens = [last_output_token]
    write_idxs = prefill_write_idxs[:, -1:] + 1
    position_idxs = prefill_position_idxs[-1:] + 1
    for i in range(new_tokens):
        # Round the visible cache length up to a multiple of 256 (keeping the set of shapes
        # seen by torch.compile(dynamic=False) small); the +1 keeps the slot being written
        # (index prefill_length + i) inside the slice.
        chunk_length = 256 * ((prefill_length + i + 1 + 255) // 256)
        # print(chunk_length)
        output_logits = model(last_output_token, kcache[:, :, :, :chunk_length], vcache[:, :, :, :chunk_length], write_idxs, position_idxs, attn_mask=None, causal=False)
        # cur_tokens = torch.cat([input_ids] + all_output_tokens, dim=1)
        # output_logits = model(cur_tokens, None, None, None, position_idxs, attn_mask=mask, causal=False)
        output_logits = F.softmax(output_logits[:, -1:], dim=-1)
        # next_output_token = torch.multinomial(output_logits, 1)
        next_output_token = torch.argmax(output_logits, dim=-1)
        write_idxs = write_idxs.add_(1)
        position_idxs = position_idxs.add_(1)
        last_output_token = next_output_token
        all_output_tokens.append(next_output_token)
    return torch.cat(all_output_tokens, dim=1)
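# Sketch of greedy decoding with generate() (assumes `tokenizer` and `our_model` as set
# up in the __main__ block below):
#   ids = tokenizer(["Hello.<think>"], return_tensors="pt").input_ids.cuda()
#   new = generate(our_model, ids, 10)
#   print(tokenizer.decode(torch.cat([ids, new], dim=1)[0]))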
if __name__ == "__main__":
    import time
    import torch.cuda
    from transformers import AutoModelForCausalLM, AutoTokenizer
    def benchmark_model(model: FFFormer, inputs, num_runs=50, warmup=10):
        # Move everything to GPU and clear cache
        input_ids = inputs["input_ids"].cuda()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        N, T = input_ids.shape
        head_dim = model.cfg.hidden_size // model.cfg.num_attention_heads
        kcache = torch.zeros(model.cfg.num_hidden_layers, N, model.cfg.num_key_value_heads, T, head_dim, dtype=model.cfg.dtype).cuda()
        vcache = torch.zeros(model.cfg.num_hidden_layers, N, model.cfg.num_key_value_heads, T, head_dim, dtype=model.cfg.dtype).cuda()
        write_idxs = torch.arange(T)[None, :].cuda()
        position_idxs = torch.arange(T).cuda()
        mask = model.make_mask(input_ids)
        # Warmup runs
        for _ in range(warmup):
            with torch.no_grad():
                _ = model(input_ids, kcache, vcache, write_idxs, position_idxs, mask)
        # Measure GPU memory
        torch.cuda.synchronize()
        memory_allocated = torch.cuda.memory_allocated() / 1024**2  # MB
        memory_reserved = torch.cuda.memory_reserved() / 1024**2  # MB
        # Benchmark runs
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.no_grad():
                _ = model(input_ids, kcache, vcache, write_idxs, position_idxs, mask)
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_runs
        return avg_time, memory_allocated, memory_reserved
    def benchmark_hf_model(model, inputs, num_runs=50, warmup=10):
        # Move everything to GPU and clear cache
        inputs = {k: v.cuda() for k, v in inputs.items()}
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        # Warmup runs
        for _ in range(warmup):
            with torch.no_grad():
                _ = model(**inputs)
        # Measure GPU memory
        torch.cuda.synchronize()
        memory_allocated = torch.cuda.memory_allocated() / 1024**2  # MB
        memory_reserved = torch.cuda.memory_reserved() / 1024**2  # MB
        # Benchmark runs
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            with torch.no_grad():
                _ = model(**inputs)
        torch.cuda.synchronize()
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_runs
        return avg_time, memory_allocated, memory_reserved
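    # Note: benchmark_model / benchmark_hf_model time a single full forward pass
    # (prefill) per run; the loop below benchmarks end-to-end generation instead
    # and does not call them.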
print("\n=== Starting Benchmark ===")
# Load models
model_name = "Qwen/Qwen2.5-7B"
# model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").cuda()
our_model = FFFormer.from_llamaish_hf_model(hf_model.config, hf_model.state_dict()).cuda()
hf_model.eval()
our_model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prepare inputs of different lengths
sequences = [
"Hello.<think>", # Short
"The quick brown fox jumps over the lazy dog. " * 16, # Medium
"The quick brown fox jumps over the lazy dog. " * 128, # Long
"The quick brown fox jumps over the lazy dog. " * 1280, # Long
]
torch.set_grad_enabled(False)
inputs = tokenizer([sequences[0]], return_tensors="pt", add_special_tokens=True)
inputs = {k:v.cuda() for k,v in inputs.items()}
# res = tokenizer.decode(hf_model.generate(**inputs, max_new_tokens=10, do_sample=False)[0])
#our_res = tokenizer.decode(torch.cat([inputs["input_ids"], generate(our_model, inputs["input_ids"], 10)], dim=1)[0])
#print("hf", res)
#print("us", our_res)
print("\nBenchmarking with different sequence lengths:")
print(f"{'Sequence Length':<15} {'HF Time (ms)':<15} {'Our Time (ms)':<15} {'HF Memory (MB)':<15} {'Our Memory (MB)':<15} {'Speedup':<10}")
print("-" * 80)
    for seq in sequences:
        torch.cuda.empty_cache()
        torch._dynamo.reset()
        inputs = tokenizer([seq], return_tensors="pt", add_special_tokens=True)
        seq_len = inputs["input_ids"].shape[1]
        inputs = {k: v.cuda() for k, v in inputs.items()}
        cat_inputs = inputs["input_ids"].repeat(32, 1)
        gen_tok_warmup = generate(our_model, cat_inputs, 10000)
        torch.cuda.synchronize()
        t = time.perf_counter()
        gen_tok = torch.cat([cat_inputs, generate(our_model, cat_inputs, 10000)], dim=1)
        gen_dt = time.perf_counter() - t
        n_new = gen_tok.shape[1] - cat_inputs.shape[1]
        print(f"generation of 32x{n_new} new toks (prompt len {seq_len}) took {gen_dt:.2f}s: {n_new / gen_dt:.1f} tok/s per seq, {32 * n_new / gen_dt:.1f} tok/s total")