# muP Transformer with Depth scaling (gist by @dvruette)
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    import flash_attn
    import flash_attn.layers.rotary

    torch.backends.cuda.enable_flash_sdp(enabled=True)
    has_flash_attn = True
except ImportError:
    has_flash_attn = False

####
# PyTorch implementation of https://arxiv.org/pdf/2405.15712
# with alpha_L = 1.0
####
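# Scaling rules as implemented below (a summary of this code, not a restatement of the paper):
#   - residual branches are scaled by alpha_resid / n_blocks (depth scaling),
#   - attention logits use a split scale controlled by alpha_attn (see the note after Attention),
#   - the output head is zero-initialized and hidden matrix params are trained with a
#     learning rate of base_lr / hidden_size (muP-style width scaling, see get_param_groups).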
#################################################################################
# Layers #
#################################################################################
class RMSNorm(nn.Module):
    def __init__(self, dim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros([dim]))
        self.bias = nn.Parameter(torch.zeros([dim])) if bias else None
        self.dim = dim

    def forward(self, x):
        tgt_shape = (1,) * (x.ndim - 1) + (self.dim,)
        x = F.rms_norm(x.float(), [self.dim])
        x *= (1 + self.weight.view(*tgt_shape))
        x += (self.bias.view(*tgt_shape) if self.bias is not None else 0)
        return x


class Rotary(nn.Module):
    def __init__(self, dim, base=10_000, max_seq_len=512):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self.precompute()

    def precompute(self):
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(self.max_seq_len).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq.clone())
        emb = torch.cat((freqs, freqs), dim=-1)
        # dims are: batch, seq_len, qkv, head, dim
        cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        # This makes the transformation on v an identity.
        cos_cached[:, :, 2, :, :].fill_(1.)
        sin_cached[:, :, 2, :, :].fill_(0.)
        self.register_buffer('cos_cached', cos_cached)
        self.register_buffer('sin_cached', sin_cached)

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        # Slice along the sequence dimension (dim 1) of the cached tables.
        return self.cos_cached[:, :seq_len], self.sin_cached[:, :seq_len]


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(qkv, cos, sin):
    if has_flash_attn:
        # flash_attn expects cos/sin of shape (seq_len, rotary_dim / 2) and rotates only q and k.
        cos = cos[0, :, 0, 0, :cos.shape[-1] // 2]
        sin = sin[0, :, 0, 0, :sin.shape[-1] // 2]
        return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
    else:
        # Fallback: v is left untouched because its cos/sin entries are fixed to 1 and 0 above.
        return (qkv * cos) + (rotate_half(qkv) * sin)


class MLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, bias=True) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.fc_2 = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc_1(x)
        x = F.relu(x) ** 2
        x = self.fc_2(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, n_heads, alpha_attn=0.5):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.alpha_attn = alpha_attn
        self.d_head = dim // n_heads
        self.qk_scale = 1 / self.d_head**(1.0 - self.alpha_attn)
        self.attn_scale = 1 / self.d_head**self.alpha_attn
        self.qk_proj = nn.Linear(dim, 2 * dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, rotary_cos_sin, seqlens=None):
        batch_size, seq_len = x.shape[0], x.shape[1]
        qk = self.qk_scale * self.qk_proj(x)
        v = self.v_proj(x)
        qkv = torch.cat((qk, v), dim=-1)
        qkv = rearrange(
            qkv,
            'b s (three h d) -> b s three h d',
            three=3,
            h=self.n_heads,
        )
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(
            qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
        )
        q, k, v = qkv[:, :, 0].transpose(1, 2), qkv[:, :, 1].transpose(1, 2), qkv[:, :, 2].transpose(1, 2)
        x = F.scaled_dot_product_attention(q, k, v, scale=self.attn_scale, is_causal=True)
        x = rearrange(x, 'b h s d -> b s (h d)', b=batch_size)
        x = self.out_proj(x)
        return x
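
# Note on the attention scaling above (a reading of this code, not taken from the paper):
# q and k are multiplied by qk_scale = d_head**(alpha_attn - 1) in the forward pass, and the
# attention logits are scaled by attn_scale = d_head**(-alpha_attn). The extra qk_scale factor
# is compensated at init time in KirbyTransformer.init_weights, which multiplies the qk_proj
# init std by d_head**(1 - alpha_attn). With the default alpha_attn = 0.5, both scales reduce
# to 1 / sqrt(d_head).
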
#################################################################################
# Core Model #
#################################################################################
class EmbeddingLayer(nn.Module):
    def __init__(self, hidden_size, vocab_dim):
        super().__init__()
        self.embedding = nn.Parameter(torch.empty((vocab_dim, hidden_size)))
        nn.init.normal_(self.embedding, std=0.02)

    def forward(self, x):
        return self.embedding[x]


class Block(nn.Module):
    def __init__(self, hidden_size, n_heads, mlp_ratio=4, resid_scale=1.0, alpha_attn=0.5):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        self.mlp_ratio = mlp_ratio
        self.resid_scale = resid_scale
        self.norm1 = RMSNorm(hidden_size, bias=False)
        self.attn = Attention(hidden_size, n_heads, alpha_attn=alpha_attn)
        self.norm2 = RMSNorm(hidden_size, bias=False)
        self.mlp = MLP(hidden_size, hidden_size * mlp_ratio, bias=True)

    def forward(self, x, rotary_cos_sin, seqlens=None):
        x = x + self.resid_scale * self.attn(self.norm1(x), rotary_cos_sin, seqlens)
        x = x + self.resid_scale * self.mlp(self.norm2(x))
        return x
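
# Depth scaling as implemented here: each residual branch is scaled by
# resid_scale = alpha_resid / n_blocks (set in KirbyTransformer below), i.e.
#     x_{l+1} = x_l + (alpha_resid / L) * f_l(norm(x_l)),
# which keeps the summed contribution of the L residual branches roughly constant as depth grows.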


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, bias=False)
        self.proj = nn.Linear(hidden_size, out_channels)

    def forward(self, x):
        x = self.norm_final(x)
        x = self.proj(x)
        return x


class KirbyTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        n_heads: int,
        n_blocks: int,
        max_seq_len: int,
        alpha_in: float = 1.0,
        alpha_out: float = 1.0,
        alpha_resid: float = 1.0,
        alpha_attn: float = 0.5,
        init_scale: float = 0.02,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.max_seq_len = max_seq_len
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.alpha_resid = alpha_resid
        self.alpha_attn = alpha_attn
        self.init_scale = init_scale

        self.vocab_embed = EmbeddingLayer(self.hidden_size, self.vocab_size)
        self.rotary_emb = Rotary(
            self.hidden_size // self.n_heads,
            max_seq_len=self.max_seq_len,
        )
        self.resid_scale = alpha_resid / self.n_blocks
        self.blocks = nn.ModuleList([
            Block(
                hidden_size=self.hidden_size,
                n_heads=self.n_heads,
                resid_scale=self.resid_scale,
                alpha_attn=alpha_attn,
            )
            for _ in range(self.n_blocks)
        ])
        self.output_layer = FinalLayer(self.hidden_size, self.vocab_size)

        self.init_weights(init_scale)

    def init_weights(self, init_scale=0.02):
        # input params
        nn.init.normal_(self.vocab_embed.embedding, std=init_scale)
        # output params
        nn.init.zeros_(self.output_layer.norm_final.weight)
        nn.init.zeros_(self.output_layer.proj.weight)
        nn.init.zeros_(self.output_layer.proj.bias)
        # other params
        for name, param in self.blocks.named_parameters():
            if param.ndim == 2:
                # linear layers
                init_std = init_scale / param.shape[1]**0.5
                if "qk_proj" in name:
                    init_std *= (self.hidden_size // self.n_heads)**(1 - self.alpha_attn)
                nn.init.normal_(param, std=init_std)
            elif param.ndim == 1:
                # layer norm, biases
                nn.init.zeros_(param)
            else:
                raise ValueError(f"Unknown parameter shape: {name} (shape={param.shape})")

    def forward(self, input_ids):
        x = self.alpha_in * self.vocab_embed(input_ids)
        rotary_cos_sin = self.rotary_emb(x)
        for block in self.blocks:
            x = block(x, rotary_cos_sin, seqlens=None)
        x = self.alpha_out * self.output_layer(x)
        return x

    def get_param_groups(self, base_lr, weight_decay=0.0):
        groups = defaultdict(list)
        for name, param in self.named_parameters():
            if "vocab_embed" in name or param.ndim < 2:
                lr = base_lr
            else:
                lr = base_lr / self.hidden_size
            wd = 0.0 if param.ndim == 1 else weight_decay
            groups[(lr, wd)].append(param)
        return [
            {"params": params, "lr": lr, "weight_decay": wd}
            for (lr, wd), params in groups.items() if len(params) > 0
        ]
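

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original gist): the hyperparameters below are
    # illustrative placeholders, and AdamW is only one reasonable optimizer choice.
    # Run on CPU without flash_attn, or move model and inputs to CUDA (fp16/bf16) if
    # flash_attn is installed, since its rotary kernel requires CUDA tensors.
    model = KirbyTransformer(
        vocab_size=32_000,
        hidden_size=512,
        n_heads=8,
        n_blocks=12,
        max_seq_len=512,
    )
    # get_param_groups applies the muP-style learning rates: matrix-shaped params in the blocks
    # and output head get base_lr / hidden_size, embeddings and 1D params keep the base lr.
    param_groups = model.get_param_groups(base_lr=1e-2, weight_decay=0.1)
    optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.95))

    input_ids = torch.randint(0, 32_000, (2, 128))
    logits = model(input_ids)  # (batch, seq_len, vocab_size)
    # Dummy loss against the inputs, purely as a smoke test for the backward pass.
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), input_ids.view(-1))
    loss.backward()
    optimizer.step()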