multi head latent attention (MLA)
# https://x.com/shxf0072/status/1873038335427658011
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from collections import OrderedDict

from ohara.modules.norm import RMSNorm
from ohara.embedings_pos.rotatry import precompute_freqs_cis
from ohara.embedings_pos.rotatry import apply_rope

from torch import Tensor
from rich import print, traceback

traceback.install()
@dataclass
class Config(OrderedDict):
    vocab_size: int
    seq_len: int
    d_model: int

    num_heads: int = None
    v_head_dim: int = None
    nope_head_dim: int = None
    rope_head_dim: int = None

    hidden_dim: int = None
    num_kv_heads: int = None
    num_layers: int = 4
    dropout: float = 0.0
    bias: bool = False
    weight_tying: bool = False
    activation: str = "silu"
    mlp: str = "GLU"

    kv_lora_rank: int = None
    q_lora_rank: int = None
    attn_type: str = "mla"

    def __init__(self, **kwargs):
        super().__init__()
        for key, value in kwargs.items():
            setattr(self, key, value)
# ======================================================================================
# ||>>>> Note <<<<||
# --------------------------------------------------------------------------------------
# the DeepSeek-V2 code does a few things differently from the paper, e.g.:
# 1. k_rope is a projection from d_model (hidden_dim), while in the paper it comes from compressed_kv
# 2. q_rope comes from compressed_q (in both the paper and the code)
# 3. there are norms on the compressed q and kv
# 4. norm is applied to q_nope, q_rope, k_nope and v,
#    but not to k_rope (unclear whether the rope part of k should be normalized)
# 5. there is no merged-weight inference code for MLA
# ======================================================================================
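
# --------------------------------------------------------------------------------------
# Illustration only (not in the original gist): point 1 above, written out as modules.
# The implementation below uses the code-style projection; the "paper-style" line is an
# assumption based purely on the note above, not something tested here.
#
#   code-style : k_rope_linear = nn.Linear(d_model, rope_head_dim, bias=False)
#   paper-style: k_rope_linear = nn.Linear(kv_lora_rank, rope_head_dim, bias=False)
#                (decoupled rope key projected from compressed_kv instead of d_model)
# --------------------------------------------------------------------------------------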
# --- MLA ---
class MultiHeadLatentAttention(nn.Module):
    """
    Multi Head Latent Attention
    paper: https://arxiv.org/pdf/2405.04434

    TLDR:
    k and v are low rank; this variant of attention projects q, k, v to a low rank
    to save memory, replacing the big linear projections with lora(ish) down/up pairs.

    by joey00072 (https://github.com/joey00072)
    """
    def __init__(self, config: Config):
        super().__init__()
        assert config.v_head_dim is not None, f"v_head_dim is not defined {config.v_head_dim=}"
        assert config.q_lora_rank is not None, f"q_lora_rank is not defined {config.q_lora_rank=}"
        assert config.kv_lora_rank is not None, f"kv_lora_rank is not defined {config.kv_lora_rank=}"
        assert config.rope_head_dim is not None, f"rope_head_dim is not defined {config.rope_head_dim=}"

        self.config = config

        self.dim = config.d_model
        self.num_heads = config.num_heads
        self.v_head_dim = config.v_head_dim

        self.nope_head_dim = config.nope_head_dim
        self.rope_head_dim = config.rope_head_dim

        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank

        self.dropout = config.dropout

        # note: the head dim of query and key is different from the head dim of value
        # (attention_dim == num_heads*head_dim) > d_model in deepseekv2

        # this is the dim between wV and wO (the output projection)
        self.value_dim = self.num_heads * self.v_head_dim

        # these are the dims between wQ and wK
        self.nope_dim = self.num_heads * self.nope_head_dim
        self.rope_dim = self.num_heads * self.rope_head_dim

        # query compression
        self.compress_q_linear = nn.Linear(self.dim, self.q_lora_rank, bias=False)  # W_DQ
        self.decompress_q_nope = nn.Linear(self.q_lora_rank, self.nope_dim, bias=False)
        self.decompress_q_rope = nn.Linear(self.q_lora_rank, self.rope_dim, bias=False)
        self.q_norm = RMSNorm(dim=self.q_lora_rank)

        # key and value compression
        self.compress_kv_linear = nn.Linear(self.dim, self.kv_lora_rank, bias=False)  # W_DKV
        self.decompress_k_nope = nn.Linear(self.kv_lora_rank, self.nope_dim, bias=False)
        self.decompress_v_linear = nn.Linear(self.kv_lora_rank, self.value_dim, bias=False)
        self.kv_norm = RMSNorm(dim=self.kv_lora_rank)

        # decoupled rope key, shared across heads (projected from d_model, see note above)
        self.k_rope_linear = nn.Linear(self.dim, self.rope_head_dim, bias=False)
        # self.rope_norm = RMSNorm(self.rope_dim)  # not in deepseekv2

        self.proj = nn.Linear(self.value_dim, self.dim, bias=False)
        self.res_dropout = nn.Dropout(p=config.dropout)

        self.scale = 1 / (self.value_dim**0.5)
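
    # ----------------------------------------------------------------------------------
    # Not in the original gist: at inference time the natural KV cache for this layer
    # would be the kv_lora_rank-sized compressed kv plus the single shared
    # rope_head_dim-sized key_rope per token, rather than full per-head k and v
    # (this is my reading of how the paper describes MLA caching; no cache is
    # implemented below).
    # ----------------------------------------------------------------------------------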
    def forward(self, x: Tensor, mask: torch.Tensor, freqs_cis: Tensor):
        batch_size, seq_len, _ = x.shape

        compressed_q = self.compress_q_linear(x)
        norm_q = self.q_norm(compressed_q)
        query_nope: Tensor = self.decompress_q_nope(norm_q)
        query_rope: Tensor = self.decompress_q_rope(norm_q)

        compressed_kv = self.compress_kv_linear(x)
        norm_kv = self.kv_norm(compressed_kv)
        key_nope: Tensor = self.decompress_k_nope(norm_kv)
        value: Tensor = self.decompress_v_linear(norm_kv)

        key_rope: Tensor = self.k_rope_linear(x)
        # norm_rope = self.rope_norm(key_rope)

        query_nope = query_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1, 2)
        query_rope = query_rope.view(batch_size, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)

        key_rope = key_rope.view(batch_size, seq_len, 1, self.rope_head_dim).transpose(1, 2)
        key_nope = key_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1, 2)

        value = value.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)

        # note: F.scaled_dot_product_attention also applies its own 1/sqrt(d_qk) scaling internally
        # query_nope = query_nope * self.scale
        # key_nope = key_nope * self.scale
        value = value * self.scale

        q_rope, k_rope = apply_rope(query_rope, key_rope, cis=freqs_cis)

        # preallocate the recombined q/k in the input's dtype and device
        q_recombined = torch.empty((batch_size, self.num_heads, seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device, dtype=x.dtype)
        k_recombined = torch.empty((batch_size, self.num_heads, seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device, dtype=x.dtype)

        q_recombined[:, :, :, : self.nope_head_dim] = query_nope
        q_recombined[:, :, :, self.nope_head_dim :] = q_rope

        # k_rope = torch.repeat_interleave(k_rope, self.num_heads, dim=1)  # >> you don't need to do this <<
        # 👇 broadcasting replicates k_rope across all heads automagically
        k_recombined[:, :, :, : self.nope_head_dim] = key_nope
        k_recombined[:, :, :, self.nope_head_dim :] = k_rope

        output = F.scaled_dot_product_attention(q_recombined, k_recombined, value, is_causal=True, dropout_p=self.dropout)

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.v_head_dim)

        output = self.proj(output)
        output = self.res_dropout(output)
        return output
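
# --------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). Assumptions: ohara's
# precompute_freqs_cis(head_dim, seq_len) follows the usual llama-style signature, and
# apply_rope accepts the (batch, heads, seq, rope_head_dim) layout used above.
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    config = Config(
        vocab_size=32_000,
        seq_len=64,
        d_model=256,
        num_heads=4,
        v_head_dim=32,
        nope_head_dim=32,
        rope_head_dim=16,
        q_lora_rank=64,
        kv_lora_rank=32,
        dropout=0.0,
    )
    mla = MultiHeadLatentAttention(config)
    x = torch.randn(2, config.seq_len, config.d_model)
    freqs_cis = precompute_freqs_cis(config.rope_head_dim, config.seq_len)
    out = mla(x, mask=None, freqs_cis=freqs_cis)
    print(out.shape)  # expected: torch.Size([2, 64, 256])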