muP Transformer with Depth scaling
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    import flash_attn
    import flash_attn.layers.rotary
    torch.backends.cuda.enable_flash_sdp(enabled=True)
    has_flash_attn = True
except ImportError:
    has_flash_attn = False

####
# PyTorch implementation of https://arxiv.org/pdf/2405.15712
# with alpha_L = 1.0
##
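####
# Summary of the scalings as implemented below:
#   - residual branches are scaled by resid_scale = alpha_resid / n_blocks (1/depth)
#   - attention applies qk_scale = d_head**-(1 - alpha_attn) to the qk projection
#     (compensated by a larger qk_proj init) and uses scale = d_head**-alpha_attn in SDPA
#   - 2-D hidden weights train with lr = base_lr / hidden_size, while embeddings,
#     norm scales, and biases stay at base_lr (see get_param_groups)
#   - the output head (FinalLayer.proj) is zero-initialized
##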

#################################################################################
#                                    Layers                                     #
#################################################################################

class RMSNorm(nn.Module):
    # Uses a (1 + weight) scale so that zero-initialized weights give a unit scale.
    def __init__(self, dim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros([dim]))
        self.bias = nn.Parameter(torch.zeros([dim])) if bias else None
        self.dim = dim

    def forward(self, x):
        tgt_shape = (1,) * (x.ndim - 1) + (self.dim,)
        x = F.rms_norm(x.float(), [self.dim])
        x *= (1 + self.weight.view(*tgt_shape))
        x += (self.bias.view(*tgt_shape) if self.bias is not None else 0)
        return x

class Rotary(nn.Module):
    def __init__(self, dim, base=10_000, max_seq_len=512):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_seq_len = max_seq_len
        self.precompute()

    def precompute(self):
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(self.max_seq_len).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq.clone())
        emb = torch.cat((freqs, freqs), dim=-1)
        # dims are: batch, seq_len, qkv, head, dim
        cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
        # This makes the transformation on v an identity.
        cos_cached[:, :, 2, :, :].fill_(1.)
        sin_cached[:, :, 2, :, :].fill_(0.)
        self.register_buffer('cos_cached', cos_cached)
        self.register_buffer('sin_cached', sin_cached)

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        # Slice along the sequence dimension (dim 1), not the qkv dimension.
        return self.cos_cached[:, :seq_len], self.sin_cached[:, :seq_len]

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(qkv, cos, sin):
    if has_flash_attn:
        cos = cos[0, :, 0, 0, :cos.shape[-1] // 2]
        sin = sin[0, :, 0, 0, :sin.shape[-1] // 2]
        return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
    else:
        return (qkv * cos) + (rotate_half(qkv) * sin)

class MLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, bias=True) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.fc_2 = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc_1(x)
        x = F.relu(x) ** 2  # squared-ReLU activation
        x = self.fc_2(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, n_heads, alpha_attn=0.5):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.alpha_attn = alpha_attn
        self.d_head = dim // n_heads
        # muP attention scaling: the qk projection output is shrunk at forward time
        # (its init is enlarged correspondingly in init_weights), and the SDPA softmax
        # scale is 1 / d_head**alpha_attn (equal to the usual 1 / sqrt(d_head) when
        # alpha_attn = 0.5).
        self.qk_scale = 1 / self.d_head**(1.0 - self.alpha_attn)
        self.attn_scale = 1 / self.d_head**self.alpha_attn
        self.qk_proj = nn.Linear(dim, 2 * dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, rotary_cos_sin, seqlens=None):
        batch_size, seq_len = x.shape[0], x.shape[1]
        qk = self.qk_scale * self.qk_proj(x)
        v = self.v_proj(x)
        qkv = torch.cat((qk, v), dim=-1)
        qkv = rearrange(
            qkv,
            'b s (three h d) -> b s three h d',
            three=3,
            h=self.n_heads,
        )
        cos, sin = rotary_cos_sin
        qkv = apply_rotary_pos_emb(
            qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
        )
        q, k, v = qkv[:, :, 0].transpose(1, 2), qkv[:, :, 1].transpose(1, 2), qkv[:, :, 2].transpose(1, 2)
        x = F.scaled_dot_product_attention(q, k, v, scale=self.attn_scale, is_causal=True)
        x = rearrange(x, 'b h s d -> b s (h d)', b=batch_size)
        x = self.out_proj(x)
        return x

#################################################################################
#                                  Core Model                                   #
#################################################################################

class EmbeddingLayer(nn.Module):
    def __init__(self, hidden_size, vocab_dim):
        super().__init__()
        self.embedding = nn.Parameter(torch.empty((vocab_dim, hidden_size)))
        nn.init.normal_(self.embedding, std=0.02)

    def forward(self, x):
        return self.embedding[x]

class Block(nn.Module):
    def __init__(self, hidden_size, n_heads, mlp_ratio=4, resid_scale=1.0, alpha_attn=0.5):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        self.mlp_ratio = mlp_ratio
        self.resid_scale = resid_scale
        self.norm1 = RMSNorm(hidden_size, bias=False)
        self.attn = Attention(hidden_size, n_heads, alpha_attn=alpha_attn)
        self.norm2 = RMSNorm(hidden_size, bias=False)
        self.mlp = MLP(hidden_size, hidden_size * mlp_ratio, bias=True)

    def forward(self, x, rotary_cos_sin, seqlens=None):
        # Each residual branch is scaled by resid_scale (= alpha_resid / n_blocks).
        x = x + self.resid_scale * self.attn(self.norm1(x), rotary_cos_sin, seqlens)
        x = x + self.resid_scale * self.mlp(self.norm2(x))
        return x

class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = RMSNorm(hidden_size, bias=False)
        self.proj = nn.Linear(hidden_size, out_channels)

    def forward(self, x):
        x = self.norm_final(x)
        x = self.proj(x)
        return x

class KirbyTransformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        n_heads: int,
        n_blocks: int,
        max_seq_len: int,
        alpha_in: float = 1.0,
        alpha_out: float = 1.0,
        alpha_resid: float = 1.0,
        alpha_attn: float = 0.5,
        init_scale: float = 0.02,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_blocks = n_blocks
        self.max_seq_len = max_seq_len
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.alpha_resid = alpha_resid
        self.alpha_attn = alpha_attn
        self.init_scale = init_scale

        self.vocab_embed = EmbeddingLayer(self.hidden_size, self.vocab_size)
        self.rotary_emb = Rotary(
            self.hidden_size // self.n_heads,
            max_seq_len=self.max_seq_len,
        )
        self.resid_scale = alpha_resid / self.n_blocks
        self.blocks = nn.ModuleList([
            Block(
                hidden_size=self.hidden_size,
                n_heads=self.n_heads,
                resid_scale=self.resid_scale,
                alpha_attn=alpha_attn,
            )
            for _ in range(self.n_blocks)
        ])
        self.output_layer = FinalLayer(self.hidden_size, self.vocab_size)
        self.init_weights(init_scale)

    def init_weights(self, init_scale=0.02):
        # input params
        nn.init.normal_(self.vocab_embed.embedding, std=init_scale)
        # output params: zero-init so logits start at zero
        nn.init.zeros_(self.output_layer.norm_final.weight)
        nn.init.zeros_(self.output_layer.proj.weight)
        nn.init.zeros_(self.output_layer.proj.bias)
        # other params
        for name, param in self.blocks.named_parameters():
            if param.ndim == 2:
                # linear layers: std ~ 1 / sqrt(fan_in)
                init_std = init_scale / param.shape[1]**0.5
                if "qk_proj" in name:
                    # compensate the forward-time qk_scale = d_head**-(1 - alpha_attn)
                    init_std *= (self.hidden_size // self.n_heads)**(1 - self.alpha_attn)
                nn.init.normal_(param, std=init_std)
            elif param.ndim == 1:
                # layer norm, biases
                nn.init.zeros_(param)
            else:
                raise ValueError(f"Unknown parameter shape: {name} (shape={param.shape})")

    def forward(self, input_ids):
        x = self.alpha_in * self.vocab_embed(input_ids)
        rotary_cos_sin = self.rotary_emb(x)
        for block in self.blocks:
            x = block(x, rotary_cos_sin, seqlens=None)
        x = self.alpha_out * self.output_layer(x)
        return x

    def get_param_groups(self, base_lr, weight_decay=0.0):
        # muP learning rates: 2-D hidden weights get base_lr / hidden_size, while
        # embeddings and 1-D params (norm scales, biases) stay at base_lr.
        groups = defaultdict(list)
        for name, param in self.named_parameters():
            if "vocab_embed" in name or param.ndim < 2:
                lr = base_lr
            else:
                lr = base_lr / self.hidden_size
            wd = 0.0 if param.ndim == 1 else weight_decay
            groups[(lr, wd)].append(param)
        return [
            {"params": params, "lr": lr, "weight_decay": wd}
            for (lr, wd), params in groups.items() if len(params) > 0
        ]
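
####
# Usage sketch: a minimal, illustrative way to wire the model up with AdamW using
# the muP param groups above. The hyperparameters below (vocab size, width, depth,
# base_lr, weight decay, batch shape) are placeholder assumptions. Note that the
# flash-attn rotary path expects fp16/bf16 tensors on GPU; this toy step assumes
# the pure-PyTorch fallback.
##
if __name__ == "__main__":
    model = KirbyTransformer(
        vocab_size=32_000,   # assumed tokenizer size
        hidden_size=512,
        n_heads=8,
        n_blocks=12,
        max_seq_len=512,
    )
    optimizer = torch.optim.AdamW(model.get_param_groups(base_lr=1e-3, weight_decay=0.1))

    # One dummy next-token-prediction step on random tokens.
    input_ids = torch.randint(0, 32_000, (2, 128))   # (batch, seq_len)
    logits = model(input_ids)                        # (batch, seq_len, vocab_size)
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    loss.backward()
    optimizer.step()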