llama4-for-quantization
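Helper script that unfuses the packed Llama4TextExperts MoE weights of Llama 4 Scout into per-expert Llama4TextMLP modules, so each expert exposes standard gate_proj / up_proj / down_proj linear layers that per-module quantization tooling can handle.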
from transformers import AutoModelForCausalLM, AutoTokenizer, Llama4ForConditionalGeneration, Llama4Processor
from transformers.quantizers.quantizers_utils import get_module_from_name
import torch
def convert_model_for_quantization(model):
    """Replace every fused Llama4TextExperts module with per-expert MLPs."""
    for name, module in model.named_modules():
        module_class_name = module.__class__.__name__
        if module_class_name == "Llama4TextExperts":
            # Access the fused expert weights
            gate_up_proj = module.gate_up_proj  # shape: (num_experts, hidden_size, intermediate_size * 2)
            down_proj = module.down_proj        # shape: (num_experts, intermediate_size, hidden_size)
            # Swap the fused module for a list of standard per-expert MLPs
            parent_module, module_name = get_module_from_name(model, name)
            parent_module._modules[module_name] = SequentialLlama4TextExperts(
                model.config.get_text_config(),
                gate_up_proj,
                down_proj,
            )
class SequentialLlama4TextExperts(torch.nn.ModuleList):
    """
    A module that implements an unfused, per-expert version of a list of expert modules.
    This is specifically designed to work with Llama4TextExperts in MoE layers.
    """

    def __init__(self, config, gate_up_proj, down_proj):
        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP

        # Initialize one empty MLP per expert
        super().__init__([Llama4TextMLP(config) for _ in range(gate_up_proj.shape[0])])
        self.num_experts = gate_up_proj.shape[0]

        # Split the fused weights and assign them to the individual MLPs
        intermediate_size = down_proj.shape[1]
        for expert_idx in range(self.num_experts):
            # Extract weights for this expert
            expert_gate_up = gate_up_proj[expert_idx]  # (hidden_size, intermediate_size * 2)
            expert_down = down_proj[expert_idx]        # (intermediate_size, hidden_size)
            # Split gate_up into gate and up projections
            gate_proj = expert_gate_up[:, :intermediate_size]
            up_proj = expert_gate_up[:, intermediate_size:]
            # Assign weights to the MLP; transpose to match nn.Linear's
            # (out_features, in_features) layout and make them contiguous
            self[expert_idx].gate_proj.weight.data = gate_proj.t().contiguous()
            self[expert_idx].up_proj.weight.data = up_proj.t().contiguous()
            self[expert_idx].down_proj.weight.data = expert_down.t().contiguous()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Route one contiguous slice of tokens through each expert
        hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
        routed_out = torch.zeros_like(hidden_states)
        for expert_idx in range(self.num_experts):
            routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
        return routed_out
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
model_output = "llama4-reshaped-all-2"

model = Llama4ForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = Llama4Processor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

convert_model_for_quantization(model)
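
# The script defines `model_output` but never writes anything to it. A minimal
# follow-up sketch (assumption: the intent is to persist the unfused checkpoint
# so a downstream quantization tool can load it) using the standard
# `save_pretrained` APIs:

# Optional sanity check: no fused expert modules should remain after conversion.
assert not any(m.__class__.__name__ == "Llama4TextExperts" for m in model.modules())

model.save_pretrained(model_output)
processor.save_pretrained(model_output)
tokenizer.save_pretrained(model_output)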