@SunMarc · Created May 13, 2025
llama4-for-quantization

# Unfuse the fused 3D expert weights of Llama4TextExperts into per-expert
# Llama4TextMLP modules, so that standard per-nn.Linear quantization tooling
# can be applied to the MoE layers.
import torch

from transformers import AutoTokenizer, Llama4ForConditionalGeneration, Llama4Processor
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
from transformers.quantizers.quantizers_utils import get_module_from_name

def convert_model_for_quantization(model):
    """Replace every fused Llama4TextExperts module with a SequentialLlama4TextExperts."""
    # Materialize the list first: the module tree is mutated while iterating
    for name, module in list(model.named_modules()):
        if module.__class__.__name__ == "Llama4TextExperts":
            # Access the fused weights
            gate_up_proj = module.gate_up_proj  # (num_experts, hidden_size, intermediate_size * 2)
            down_proj = module.down_proj        # (num_experts, intermediate_size, hidden_size)
            parent_module, module_name = get_module_from_name(model, name)
            parent_module._modules[module_name] = SequentialLlama4TextExperts(
                model.config.get_text_config(),
                gate_up_proj,
                down_proj,
            )

class SequentialLlama4TextExperts(torch.nn.ModuleList):
    """
    Re-implements the fused Llama4TextExperts module as an explicit list of
    per-expert Llama4TextMLP modules. Unfusing the 3D expert weights exposes
    ordinary nn.Linear layers, which quantization tooling can handle.
    """
    def __init__(self, config, gate_up_proj, down_proj):
        # Initialize one empty MLP per expert
        super().__init__([Llama4TextMLP(config) for _ in range(gate_up_proj.shape[0])])
        self.num_experts = gate_up_proj.shape[0]
        intermediate_size = down_proj.shape[1]
        # Split the fused weights and assign them to the individual MLPs
        for expert_idx in range(self.num_experts):
            expert_gate_up = gate_up_proj[expert_idx]  # (hidden_size, intermediate_size * 2)
            expert_down = down_proj[expert_idx]        # (intermediate_size, hidden_size)
            # The fused tensor concatenates gate and up projections along the last dim
            gate_proj = expert_gate_up[:, :intermediate_size]
            up_proj = expert_gate_up[:, intermediate_size:]
            # Transpose to nn.Linear's (out_features, in_features) layout;
            # contiguous() avoids leaving non-contiguous storage in .data
            self[expert_idx].gate_proj.weight.data = gate_proj.t().contiguous()
            self[expert_idx].up_proj.weight.data = up_proj.t().contiguous()
            self[expert_idx].down_proj.weight.data = expert_down.t().contiguous()

    def forward(
        self,
        hidden_states: "torch.Tensor",
    ) -> "torch.Tensor":
        # Expects tokens already grouped by expert: (num_experts * tokens, hidden_size)
        hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
        routed_out = torch.zeros_like(hidden_states)
        for expert_idx in range(self.num_experts):
            routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
        return routed_out
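
# A minimal smoke test of SequentialLlama4TextExperts on toy shapes (not part
# of the original gist). It assumes Llama4TextMLP only reads hidden_size,
# intermediate_size, and hidden_act from the config; the real config comes
# from the checkpoint.
def _smoke_test_sequential_experts():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=64, intermediate_size=128, hidden_act="silu")
    num_experts, tokens_per_expert = 4, 3
    gate_up = torch.randn(num_experts, cfg.hidden_size, 2 * cfg.intermediate_size)
    down = torch.randn(num_experts, cfg.intermediate_size, cfg.hidden_size)
    experts = SequentialLlama4TextExperts(cfg, gate_up, down)
    x = torch.randn(num_experts * tokens_per_expert, cfg.hidden_size)
    out = experts(x)
    assert out.shape == (num_experts, tokens_per_expert, cfg.hidden_size)
# _smoke_test_sequential_experts()  # uncomment to run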

# Load the full-precision checkpoint, then unfuse the experts in place
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model_output = "llama4-reshaped-all-2"

model = Llama4ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)
processor = Llama4Processor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

convert_model_for_quantization(model)
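
# The gist defines model_output but never uses it; presumably the unfused
# model is saved so downstream quantization tooling can load it. A hedged
# sketch of that final step:
model.save_pretrained(model_output)
processor.save_pretrained(model_output)
tokenizer.save_pretrained(model_output)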