@scturtle
Last active August 22, 2025 07:33
gemma3 270m
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
from tokenizers import Tokenizer
CFG = {
    "vocab_size": 262_144,
    "context_length": 32_768,
    "emb_dim": 640,
    "n_heads": 4,
    "n_layers": 18,
    "hidden_dim": 2048,
    "head_dim": 256,
    "qk_norm": True,
    "n_kv_groups": 1,
    "rope_local_base": 10_000.0,
    "rope_base": 1_000_000.0,
    "sliding_window": 512,
    "layer_types": (["sliding_attention"] * 5 + ["full_attention"]) * 3,
    "dtype": torch.bfloat16,
    "query_pre_attn_scalar": 256,
}
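# Gemma 3 270M config: 18 transformer blocks in a repeating pattern of five
# sliding-window attention layers followed by one full-attention layer,
# 4 query heads sharing a single KV head (n_kv_groups=1), and two RoPE bases
# (10k for the local/sliding layers, 1M for the global/full layers).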
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        emb_dim, hidden_dim, dtype = CFG["emb_dim"], CFG["hidden_dim"], CFG["dtype"]
        self.gate_proj = nn.Linear(
            emb_dim, hidden_dim, dtype=dtype, bias=False, device="meta"
        )
        self.up_proj = nn.Linear(
            emb_dim, hidden_dim, dtype=dtype, bias=False, device="meta"
        )
        self.down_proj = nn.Linear(
            hidden_dim, emb_dim, dtype=dtype, bias=False, device="meta"
        )

    def forward(self, x):
        return self.down_proj(
            nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)
        )
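# GeGLU feed-forward: gelu(gate_proj(x)) * up_proj(x), then down_proj. The
# tanh-approximate GELU mirrors the gelu_pytorch_tanh activation used in the HF
# Gemma 3 config.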
class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(emb_dim, dtype=CFG["dtype"]))

    def forward(self, x):
        # Match HF Gemma3: compute norm in float32, then scale by (1 + w)
        input_dtype = x.dtype
        x_f = x.float()
        var = x_f.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x_f * torch.rsqrt(var + self.eps)
        out = x_norm * (1.0 + self.weight.float())
        return out.to(input_dtype)
def compute_rope_params(
    head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32
):
    inv_freq = 1.0 / (
        theta_base
        ** (
            torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float()
            / head_dim
        )
    )
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=1)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin
def apply_rope(x, cos, sin, offset=0):
    _, _, seq_len, head_dim = x.shape
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2 :]
    cos = cos[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.to(dtype=x.dtype)
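# compute_rope_params returns cos/sin tables of shape (context_length, head_dim);
# apply_rope slices them at `offset` (the absolute position of the chunk's first
# token, taken from the KV cache) and applies the usual rotate-half convention.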
class GroupedQueryAttention(nn.Module):
    def __init__(self, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        num_heads, head_dim, num_kv_groups, emb_dim, dtype = (
            CFG["n_heads"],
            CFG["head_dim"],
            CFG["n_kv_groups"],
            CFG["emb_dim"],
            CFG["dtype"],
        )
        assert num_heads % num_kv_groups == 0
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups
        self.head_dim = head_dim
        self.d_out = num_heads * head_dim
        self.q_proj = nn.Linear(
            emb_dim, self.d_out, bias=False, dtype=dtype, device="meta"
        )
        self.k_proj = nn.Linear(
            emb_dim, num_kv_groups * head_dim, bias=False, dtype=dtype, device="meta"
        )
        self.v_proj = nn.Linear(
            emb_dim, num_kv_groups * head_dim, bias=False, dtype=dtype, device="meta"
        )
        self.o_proj = nn.Linear(
            self.d_out, emb_dim, bias=False, dtype=dtype, device="meta"
        )
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.scaling = CFG["query_pre_attn_scalar"] ** -0.5

    def forward(self, x, mask, cos, sin, start_pos, cache):
        B, L, _ = x.shape
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)
        queries = queries.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        queries = apply_rope(queries, cos, sin, offset=start_pos)
        keys = apply_rope(keys, cos, sin, offset=start_pos)
        prev_k, prev_v = cache.get(self.layer_idx)
        if prev_k is not None:
            keys = torch.cat([prev_k, keys], dim=2)
            values = torch.cat([prev_v, values], dim=2)
        cache.update(self.layer_idx, (keys, values))
        res = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=mask,
            scale=self.scaling,
            enable_gqa=True,
        )
        return self.o_proj(res.transpose(1, 2).contiguous().reshape(B, L, -1))
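# With n_kv_groups=1 this is effectively multi-query attention: 4 query heads
# share one K/V head, and enable_gqa=True (available in recent PyTorch) lets
# scaled_dot_product_attention broadcast that KV head across the query heads.
# The softmax scale is query_pre_attn_scalar**-0.5, which for this config
# equals the usual head_dim**-0.5 since both values are 256.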
class TransformerBlock(nn.Module):
    def __init__(self, layer_idx, attn_type):
        super().__init__()
        self.attn_type = attn_type
        self.sliding_window = CFG["sliding_window"]
        self.self_attn = GroupedQueryAttention(layer_idx=layer_idx)
        self.mlp = FeedForward()
        emb_dim = CFG["emb_dim"]
        self.input_layernorm = RMSNorm(emb_dim)
        self.post_attention_layernorm = RMSNorm(emb_dim)
        self.pre_feedforward_layernorm = RMSNorm(emb_dim)
        self.post_feedforward_layernorm = RMSNorm(emb_dim)

    def forward(
        self,
        x,
        mask_global,
        mask_local,
        cos_global,
        sin_global,
        cos_local,
        sin_local,
        start_pos,
        cache,
    ):
        if self.attn_type == "sliding_attention":
            attn_mask = mask_local
            cos = cos_local
            sin = sin_local
        else:
            attn_mask = mask_global
            cos = cos_global
            sin = sin_global
        x = x + self.post_attention_layernorm(
            self.self_attn(
                self.input_layernorm(x), attn_mask, cos, sin, start_pos, cache
            )
        )
        x = x + self.post_feedforward_layernorm(
            self.mlp(self.pre_feedforward_layernorm(x))
        )
        return x
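# Gemma-style "sandwich" residuals: each sublayer is pre-normed on the way in
# and normed again (post_attention_layernorm / post_feedforward_layernorm)
# before being added back to the residual stream. Sliding-attention layers also
# use the local RoPE tables (base 10k) instead of the global ones (base 1M).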
class Gemma3Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(
            CFG["vocab_size"], CFG["emb_dim"], dtype=CFG["dtype"], device="meta"
        )
        self.layers = nn.ModuleList(
            [
                TransformerBlock(i, attn_type)
                for i, attn_type in enumerate(CFG["layer_types"])
            ]
        )
        self.norm = RMSNorm(CFG["emb_dim"])
        self.out_head = nn.Linear(
            CFG["emb_dim"],
            CFG["vocab_size"],
            bias=False,
            dtype=CFG["dtype"],
            device="meta",
        )
        cos_local, sin_local = compute_rope_params(
            head_dim=CFG["head_dim"],
            theta_base=CFG["rope_local_base"],
            context_length=CFG["context_length"],
            dtype=torch.float32,
        )
        cos_global, sin_global = compute_rope_params(
            head_dim=CFG["head_dim"],
            theta_base=CFG["rope_base"],
            context_length=CFG["context_length"],
            dtype=torch.float32,
        )
        self.register_buffer("cos_local", cos_local, persistent=False)
        self.register_buffer("sin_local", sin_local, persistent=False)
        self.register_buffer("cos_global", cos_global, persistent=False)
        self.register_buffer("sin_global", sin_global, persistent=False)

    def _create_masks(self, device, pos_start, pos_end):
        q_indices = torch.arange(pos_start, pos_end, device=device)[:, None]
        k_indices = torch.arange(pos_end, device=device)[None, :]
        attend_global = k_indices <= q_indices
        is_in_window = k_indices > q_indices - CFG["sliding_window"]
        attend_local = attend_global & is_in_window
        return attend_global[None, None, :, :], attend_local[None, None, :, :]
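    # _create_masks example (illustrative numbers): with sliding_window=3,
    # pos_start=4, pos_end=5 there is one query at absolute position 4; the
    # global mask admits keys 0..4, the local mask only keys 2..4. Both masks
    # are boolean with shape (1, 1, q_len, kv_len), which
    # scaled_dot_product_attention accepts directly as attn_mask.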
    def forward(self, input_ids, cache):
        _, seq_len = input_ids.shape
        x = self.embed_tokens(input_ids) * (CFG["emb_dim"] ** 0.5)
        pos_start = cache.position
        pos_end = pos_start + seq_len
        cache.position = pos_end
        mask_global, mask_local = self._create_masks(
            device=x.device, pos_start=pos_start, pos_end=pos_end
        )
        for block in self.layers:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global,
                sin_global=self.sin_global,
                cos_local=self.cos_local,
                sin_local=self.sin_local,
                start_pos=pos_start,
                cache=cache,
            )
        return self.out_head(self.norm(x))
class KVCache:
    def __init__(self):
        from collections import defaultdict

        self.cache = defaultdict(lambda: (None, None))
        self.position = 0

    def get(self, layer_idx):
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        self.cache[layer_idx] = value
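# The cache keeps the full concatenated K/V per layer and grows without bound;
# `position` tracks the absolute offset so RoPE and the masks stay aligned
# across decode steps. A production cache would trim sliding-window layers to
# the last `sliding_window` positions, but for short generations this is fine.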
def load_weights_into_gemma(model, ckpt_dir: Path):
    state_dict = {}
    for file in ckpt_dir.glob("*.safetensors"):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                loaded_tensor = f.get_tensor(weight_name)
                param = model.get_parameter(weight_name[len("model.") :])
                assert param.dtype == loaded_tensor.dtype
                assert param.shape == loaded_tensor.shape
                state_dict[weight_name[len("model.") :]] = loaded_tensor
    state_dict["out_head.weight"] = state_dict["embed_tokens.weight"]
    model.load_state_dict(state_dict, assign=True)
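# The checkpoint stores no separate lm_head (weights are tied), so the output
# head reuses embed_tokens.weight. Because all modules were created on the
# "meta" device, load_state_dict(..., assign=True) is what materializes the
# real tensors from the checkpoint.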
class GemmaTokenizer:
    def __init__(self, tok_file):
        self._tok = Tokenizer.from_file(str(tok_file))

    def encode(self, text: str) -> list[int]:
        return self._tok.encode(text).ids

    def decode(self, ids: list[int]) -> str:
        return self._tok.decode(ids, skip_special_tokens=False)

def apply_chat_template(user_text):
    return f"<start_of_turn>user\n{user_text}<end_of_turn>\n<start_of_turn>model\n"
model = Gemma3Model()
# hf download google/gemma-3-270m-it --local-dir gemma-3-270m
load_weights_into_gemma(model, Path("gemma-3-270m"))
tokenizer = GemmaTokenizer(Path("gemma-3-270m") / "tokenizer.json")
prompt = "Give me a short introduction to large language models."
prompt = apply_chat_template(prompt)
token_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
print(tokenizer.decode(token_ids[0].tolist()), end="", flush=True)
model.eval()
cache = KVCache()
eos_token_id = tokenizer.encode("<end_of_turn>")[-1]
with torch.no_grad():
    for _ in range(150):
        out = model(token_ids, cache)
        next_token = torch.argmax(out[:, -1], dim=-1, keepdim=True)
        if eos_token_id is not None and torch.all(next_token == eos_token_id):
            break
        print(tokenizer.decode(next_token[0].tolist()), end="", flush=True)
        token_ids = next_token
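# The loop above decodes greedily (argmax). A temperature / top-k sampler could
# be swapped in instead; a minimal sketch (not wired up, values illustrative):
#
#   logits = out[:, -1] / 0.7                     # temperature
#   topk = torch.topk(logits, k=50, dim=-1)       # keep the 50 best logits
#   probs = torch.softmax(topk.values, dim=-1)
#   idx = torch.multinomial(probs, num_samples=1)
#   next_token = topk.indices.gather(-1, idx)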