# gemma3 270m: minimal single-file PyTorch inference for google/gemma-3-270m-it

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
from tokenizers import Tokenizer
CFG = {
    "vocab_size": 262_144,
    "context_length": 32_768,
    "emb_dim": 640,
    "n_heads": 4,
    "n_layers": 18,
    "hidden_dim": 2048,
    "head_dim": 256,
    "qk_norm": True,
    "n_kv_groups": 1,
    "rope_local_base": 10_000.0,   # RoPE base for sliding-window layers
    "rope_base": 1_000_000.0,      # RoPE base for full-attention layers
    "sliding_window": 512,
    # 5 sliding-window layers followed by 1 full-attention layer, repeated 3 times (18 layers).
    "layer_types": (["sliding_attention"] * 5 + ["full_attention"]) * 3,
    "dtype": torch.bfloat16,
    "query_pre_attn_scalar": 256,  # attention scores are scaled by this value ** -0.5
}
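
# Lightweight consistency checks, derived only from the values above (illustrative).
assert len(CFG["layer_types"]) == CFG["n_layers"]
assert CFG["n_heads"] % CFG["n_kv_groups"] == 0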


class FeedForward(nn.Module):
    """Gated MLP: down_proj(gelu(gate_proj(x)) * up_proj(x))."""

    def __init__(self):
        super().__init__()
        emb_dim, hidden_dim, dtype = CFG["emb_dim"], CFG["hidden_dim"], CFG["dtype"]
        self.gate_proj = nn.Linear(
            emb_dim, hidden_dim, dtype=dtype, bias=False, device="meta"
        )
        self.up_proj = nn.Linear(
            emb_dim, hidden_dim, dtype=dtype, bias=False, device="meta"
        )
        self.down_proj = nn.Linear(
            hidden_dim, emb_dim, dtype=dtype, bias=False, device="meta"
        )

    def forward(self, x):
        return self.down_proj(
            F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)
        )


class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(emb_dim, dtype=CFG["dtype"]))

    def forward(self, x):
        # Match HF Gemma3: compute norm in float32, then scale by (1 + w)
        input_dtype = x.dtype
        x_f = x.float()
        var = x_f.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x_f * torch.rsqrt(var + self.eps)
        out = x_norm * (1.0 + self.weight.float())
        return out.to(input_dtype)


def compute_rope_params(
    head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32
):
    # Precompute RoPE cos/sin tables for every position up to context_length.
    inv_freq = 1.0 / (
        theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)
    )
    positions = torch.arange(context_length, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=1)  # (context_length, head_dim)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin


def apply_rope(x, cos, sin, offset=0):
    # Rotate-half RoPE: pair x[..., i] with x[..., i + head_dim // 2].
    _, _, seq_len, head_dim = x.shape
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2 :]
    cos = cos[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[offset : offset + seq_len, :].unsqueeze(0).unsqueeze(0)
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.to(dtype=x.dtype)
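
# Illustrative (commented-out) usage of the two helpers above; the shapes are
# hypothetical and chosen only for demonstration. Because each
# (x[..., i], x[..., i + head_dim // 2]) pair is rotated by the same angle,
# per-token vector norms are preserved.
#
#   cos_demo, sin_demo = compute_rope_params(head_dim=8, context_length=4)
#   x_demo = torch.randn(1, 2, 4, 8)  # (batch, n_heads, seq_len, head_dim)
#   x_rot = apply_rope(x_demo, cos_demo, sin_demo)
#   assert torch.allclose(x_rot.norm(dim=-1), x_demo.norm(dim=-1), atol=1e-5)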


class GroupedQueryAttention(nn.Module):
    def __init__(self, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        num_heads, head_dim, num_kv_groups, emb_dim, dtype = (
            CFG["n_heads"],
            CFG["head_dim"],
            CFG["n_kv_groups"],
            CFG["emb_dim"],
            CFG["dtype"],
        )
        assert num_heads % num_kv_groups == 0
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups
        self.head_dim = head_dim
        self.d_out = num_heads * head_dim
        self.q_proj = nn.Linear(
            emb_dim, self.d_out, bias=False, dtype=dtype, device="meta"
        )
        self.k_proj = nn.Linear(
            emb_dim, num_kv_groups * head_dim, bias=False, dtype=dtype, device="meta"
        )
        self.v_proj = nn.Linear(
            emb_dim, num_kv_groups * head_dim, bias=False, dtype=dtype, device="meta"
        )
        self.o_proj = nn.Linear(
            self.d_out, emb_dim, bias=False, dtype=dtype, device="meta"
        )
        # Gemma3 applies RMSNorm to queries and keys per head, before RoPE.
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.scaling = CFG["query_pre_attn_scalar"] ** -0.5

    def forward(self, x, mask, cos, sin, start_pos, cache):
        B, L, _ = x.shape
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)
        queries = queries.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(B, L, self.num_kv_groups, self.head_dim).transpose(1, 2)
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        queries = apply_rope(queries, cos, sin, offset=start_pos)
        keys = apply_rope(keys, cos, sin, offset=start_pos)
        # Append the new keys/values to the per-layer cache and attend over everything so far.
        prev_k, prev_v = cache.get(self.layer_idx)
        if prev_k is not None:
            keys = torch.cat([prev_k, keys], dim=2)
            values = torch.cat([prev_v, values], dim=2)
        cache.update(self.layer_idx, (keys, values))
        res = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=mask,
            scale=self.scaling,
            enable_gqa=True,
        )
        return self.o_proj(res.transpose(1, 2).contiguous().reshape(B, L, -1))
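
# Note: with n_kv_groups = 1 the block above is effectively multi-query attention:
# all 4 query heads share a single K/V head, and enable_gqa=True lets
# scaled_dot_product_attention broadcast that K/V head instead of materialising
# repeated copies (requires a recent PyTorch release).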


class TransformerBlock(nn.Module):
    def __init__(self, layer_idx, attn_type):
        super().__init__()
        self.attn_type = attn_type
        self.sliding_window = CFG["sliding_window"]
        self.self_attn = GroupedQueryAttention(layer_idx=layer_idx)
        self.mlp = FeedForward()
        emb_dim = CFG["emb_dim"]
        # Gemma3 normalizes both before and after each sub-block ("sandwich" norms).
        self.input_layernorm = RMSNorm(emb_dim)
        self.post_attention_layernorm = RMSNorm(emb_dim)
        self.pre_feedforward_layernorm = RMSNorm(emb_dim)
        self.post_feedforward_layernorm = RMSNorm(emb_dim)

    def forward(
        self,
        x,
        mask_global,
        mask_local,
        cos_global,
        sin_global,
        cos_local,
        sin_local,
        start_pos,
        cache,
    ):
        # Sliding-window layers use the local mask and RoPE base; full-attention layers use the global ones.
        if self.attn_type == "sliding_attention":
            attn_mask = mask_local
            cos = cos_local
            sin = sin_local
        else:
            attn_mask = mask_global
            cos = cos_global
            sin = sin_global
        x = x + self.post_attention_layernorm(
            self.self_attn(
                self.input_layernorm(x), attn_mask, cos, sin, start_pos, cache
            )
        )
        x = x + self.post_feedforward_layernorm(
            self.mlp(self.pre_feedforward_layernorm(x))
        )
        return x


class Gemma3Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(
            CFG["vocab_size"], CFG["emb_dim"], dtype=CFG["dtype"], device="meta"
        )
        self.layers = nn.ModuleList(
            [
                TransformerBlock(i, attn_type)
                for i, attn_type in enumerate(CFG["layer_types"])
            ]
        )
        self.norm = RMSNorm(CFG["emb_dim"])
        self.out_head = nn.Linear(
            CFG["emb_dim"],
            CFG["vocab_size"],
            bias=False,
            dtype=CFG["dtype"],
            device="meta",
        )
        # Two RoPE tables: a local one for sliding-window layers, a global one for full-attention layers.
        cos_local, sin_local = compute_rope_params(
            head_dim=CFG["head_dim"],
            theta_base=CFG["rope_local_base"],
            context_length=CFG["context_length"],
            dtype=torch.float32,
        )
        cos_global, sin_global = compute_rope_params(
            head_dim=CFG["head_dim"],
            theta_base=CFG["rope_base"],
            context_length=CFG["context_length"],
            dtype=torch.float32,
        )
        self.register_buffer("cos_local", cos_local, persistent=False)
        self.register_buffer("sin_local", sin_local, persistent=False)
        self.register_buffer("cos_global", cos_global, persistent=False)
        self.register_buffer("sin_global", sin_global, persistent=False)

    def _create_masks(self, device, pos_start, pos_end):
        # Boolean masks where True means "query may attend to key".
        q_indices = torch.arange(pos_start, pos_end, device=device)[:, None]
        k_indices = torch.arange(pos_end, device=device)[None, :]
        attend_global = k_indices <= q_indices                        # causal
        is_in_window = k_indices > q_indices - CFG["sliding_window"]
        attend_local = attend_global & is_in_window                   # causal + sliding window
        return attend_global[None, None, :, :], attend_local[None, None, :, :]

    def forward(self, input_ids, cache):
        _, seq_len = input_ids.shape
        # Gemma scales token embeddings by sqrt(emb_dim).
        x = self.embed_tokens(input_ids) * (CFG["emb_dim"] ** 0.5)
        pos_start = cache.position
        pos_end = pos_start + seq_len
        cache.position = pos_end
        mask_global, mask_local = self._create_masks(
            device=x.device, pos_start=pos_start, pos_end=pos_end
        )
        for block in self.layers:
            x = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos_global=self.cos_global,
                sin_global=self.sin_global,
                cos_local=self.cos_local,
                sin_local=self.sin_local,
                start_pos=pos_start,
                cache=cache,
            )
        return self.out_head(self.norm(x))


class KVCache:
    """Per-layer (keys, values) storage plus the running sequence position."""

    def __init__(self):
        from collections import defaultdict

        self.cache = defaultdict(lambda: (None, None))
        self.position = 0

    def get(self, layer_idx):
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        self.cache[layer_idx] = value
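
# Hypothetical illustration (commented out) of the cache contract used by
# GroupedQueryAttention: each layer stores its concatenated K/V tensors and
# `position` tracks how many tokens have been processed so far.
#
#   cache_demo = KVCache()
#   cache_demo.update(0, (torch.zeros(1, 1, 3, 256), torch.zeros(1, 1, 3, 256)))
#   k, v = cache_demo.get(0)  # both of shape (1, 1, 3, 256)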


def load_weights_into_gemma(model, weights_dir: Path):
    state_dict = {}
    for file in weights_dir.glob("*.safetensors"):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                loaded_tensor = f.get_tensor(weight_name)
                # HF checkpoint names are prefixed with "model."; strip it to match this module tree.
                param = model.get_parameter(weight_name[len("model.") :])
                assert param.dtype == loaded_tensor.dtype
                assert param.shape == loaded_tensor.shape
                state_dict[weight_name[len("model.") :]] = loaded_tensor
    # The output head is tied to the token embedding matrix.
    state_dict["out_head.weight"] = state_dict["embed_tokens.weight"]
    model.load_state_dict(state_dict, assign=True)


class GemmaTokenizer:
    def __init__(self, tok_file):
        self._tok = Tokenizer.from_file(str(tok_file))

    def encode(self, text: str) -> list[int]:
        return self._tok.encode(text).ids

    def decode(self, ids: list[int]) -> str:
        return self._tok.decode(ids, skip_special_tokens=False)


def apply_chat_template(user_text):
    return f"<start_of_turn>user\n{user_text}<end_of_turn>\n<start_of_turn>model\n"
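
# For example, apply_chat_template("Hi") returns
# "<start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\n",
# which is the Gemma instruction-tuned chat turn format.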


model = Gemma3Model()
# hf download google/gemma-3-270m-it --local-dir gemma-3-270m-it
load_weights_into_gemma(model, Path("gemma-3-270m-it"))
tokenizer = GemmaTokenizer(Path("gemma-3-270m-it") / "tokenizer.json")

prompt = "Give me a short introduction to large language models."
prompt = apply_chat_template(prompt)
token_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
print(tokenizer.decode(token_ids[0].tolist()), end="", flush=True)

model.eval()
cache = KVCache()
eos_token_id = tokenizer.encode("<end_of_turn>")[-1]
with torch.no_grad():
    for _ in range(150):
        out = model(token_ids, cache)
        # Greedy decoding: pick the highest-scoring token at the last position.
        next_token = torch.argmax(out[:, -1], dim=-1, keepdim=True)
        if torch.all(next_token == eos_token_id):
            break
        print(tokenizer.decode(next_token[0].tolist()), end="", flush=True)
        # Only the new token is fed back in; the KV cache holds the rest of the context.
        token_ids = next_token
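
# Possible variation (hypothetical, not what the loop above does): temperature
# sampling instead of greedy argmax, e.g.
#
#   temperature = 0.7  # hypothetical setting
#   probs = torch.softmax(out[:, -1].float() / temperature, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1)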