Created October 3, 2024 20:58
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index b63aaf1..9c268ab 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -18,6 +18,7 @@ limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 from typing import Any, Dict, Iterable, Optional, Tuple
+from torch.nn.parameter import Parameter
 import torch
 from torch import nn
@@ -44,6 +45,33 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+import types
+
+def gate_up_proj_weight_loader(
+    self,
+    param: Parameter,
+    loaded_weight: torch.Tensor,
+    loaded_shard_id: Optional[int] = None,
+):
+    if loaded_shard_id is None:
+        # Entries are (shard_id, shard_offset, shard_size).
+        shard_offsets = []
+        current_shard_offset = 0
+        for i, output_size in enumerate(self.output_sizes):
+            shard_offsets.append((i, current_shard_offset, output_size))
+            current_shard_offset += output_size
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            # nn.Linear weights are [out_features, in_features]; shard along dim 0.
+            loaded_weight_shard = loaded_weight.narrow(
+                0, shard_offset, shard_size
+            )
+            self.weight_loader(param, loaded_weight_shard, shard_id)
+    else:
+        assert loaded_shard_id < len(self.output_sizes)
+        param_data = param.data
+        shard_size = loaded_weight.shape[0]
+        shard_offset = loaded_shard_id * shard_size
+        param_data = param_data.narrow(0, shard_offset, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+    return
+
+
 class LlamaMLP(nn.Module):
@@ -56,20 +84,29 @@ class LlamaMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
+        # self.gate_up_proj = MergedColumnParallelLinear(
+        #     hidden_size,
+        #     [intermediate_size] * 2,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.gate_up_proj",
+        # )
+        self.gate_up_proj = torch.nn.Linear(
             hidden_size,
-            [intermediate_size] * 2,
+            intermediate_size * 2,
             bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.down_proj",
         )
+        self.gate_up_proj.output_sizes = [intermediate_size] * 2
+        self.gate_up_proj.weight_loader = types.MethodType(gate_up_proj_weight_loader, self.gate_up_proj)
+        self.gate_up_proj.weight.weight_loader = self.gate_up_proj.weight_loader
+        # self.down_proj = RowParallelLinear(
+        #     intermediate_size,
+        #     hidden_size,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.down_proj",
+        # )
+        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
         if hidden_act != "silu":
             raise ValueError(
                 f"Unsupported activation: {hidden_act}. "
@@ -78,12 +115,65 @@ class LlamaMLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
+        gate_up = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x = self.down_proj(x)
         return x
+
+def _get_shard_offset_mapping(self, loaded_shard_id: str):
+    shard_offset_mapping = {
+        "q": 0,
+        "k": self.num_heads * self.head_size,
+        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
+        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+    }
+    return shard_offset_mapping.get(loaded_shard_id)
+
+def _get_shard_size_mapping(self, loaded_shard_id: str):
+    shard_size_mapping = {
+        "q": self.num_heads * self.head_size,
+        "k": self.num_kv_heads * self.head_size,
+        "v": self.num_kv_heads * self.head_size,
+    }
+    return shard_size_mapping.get(loaded_shard_id)
+
+def qkv_proj_weight_loader(
+    self,
+    param: Parameter,
+    loaded_weight: torch.Tensor,
+    loaded_shard_id: Optional[str] = None,
+):
+    if loaded_shard_id is None:
+        shard_offsets = [
+            # (shard_id, shard_offset, shard_size)
+            ("q", 0, self.total_num_heads * self.head_size),
+            (
+                "k",
+                self.total_num_heads * self.head_size,
+                self.total_num_kv_heads * self.head_size,
+            ),
+            (
+                "v",
+                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
+                self.total_num_kv_heads * self.head_size,
+            ),
+        ]
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                param.output_dim, shard_offset, shard_size
+            )
+            self.weight_loader(param, loaded_weight_shard, shard_id)
+    else:
+        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
+        shard_size = self._get_shard_size_mapping(loaded_shard_id)
+        param_data = param.data
+        param_data = param_data.narrow(0, shard_offset, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+    return
+
+
 class LlamaAttention(nn.Module):
     def __init__(
         self,
@@ -125,23 +215,42 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        self.qkv_proj = QKVParallelLinear(
+        # self.qkv_proj = QKVParallelLinear(
+        #     hidden_size,
+        #     self.head_dim,
+        #     self.total_num_heads,
+        #     self.total_num_kv_heads,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.qkv_proj",
+        # )
+        self.qkv_proj = torch.nn.Linear(
             hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
         )
-        self.o_proj = RowParallelLinear(
+        self.qkv_proj.total_num_heads = self.total_num_heads
+        self.qkv_proj.head_size = self.head_dim
+        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
+        self.qkv_proj.num_heads = self.total_num_heads
+        self.qkv_proj.num_kv_heads = self.total_num_kv_heads
+        self.qkv_proj.weight_loader = types.MethodType(qkv_proj_weight_loader, self.qkv_proj)
+        self.qkv_proj._get_shard_offset_mapping = types.MethodType(_get_shard_offset_mapping, self.qkv_proj)
+        self.qkv_proj._get_shard_size_mapping = types.MethodType(_get_shard_size_mapping, self.qkv_proj)
+        self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
+        self.qkv_proj.weight.output_dim = 0
+        # self.o_proj = RowParallelLinear(
+        #     self.total_num_heads * self.head_dim,
+        #     hidden_size,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.o_proj",
+        # )
+        self.o_proj = torch.nn.Linear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
         )
-
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -164,11 +273,11 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, input_metadata)
-        output, _ = self.o_proj(attn_output)
+        output = self.o_proj(attn_output)
         return output
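
For reference, the pattern the patch relies on: sglang's tensor-parallel layers (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) are swapped for plain torch.nn.Linear modules, and a weight_loader is re-attached with types.MethodType so a loading loop that looks up `weight.weight_loader` can still copy per-shard checkpoint tensors into the fused weight. A minimal standalone sketch of that binding pattern; the sizes and the `fused_weight_loader` helper are illustrative, and it only covers the per-shard branch:

```python
import types
from typing import Optional

import torch
from torch.nn.parameter import Parameter


def fused_weight_loader(
    self,
    param: Parameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: Optional[int] = None,
):
    # Copy one shard (e.g. gate_proj or up_proj) into its slice of the
    # fused [2 * intermediate, hidden] weight; dim 0 is the output dim
    # of a plain nn.Linear weight.
    assert loaded_shard_id is not None and loaded_shard_id < len(self.output_sizes)
    shard_size = loaded_weight.shape[0]
    shard_offset = loaded_shard_id * shard_size
    param.data.narrow(0, shard_offset, shard_size).copy_(loaded_weight)


hidden, intermediate = 16, 32
gate_up = torch.nn.Linear(hidden, 2 * intermediate, bias=False)
gate_up.output_sizes = [intermediate] * 2
# Bind the loader to this module instance, then expose it on the weight so a
# generic load loop that reads `weight.weight_loader` can find it.
gate_up.weight_loader = types.MethodType(fused_weight_loader, gate_up)
gate_up.weight.weight_loader = gate_up.weight_loader

gate_w = torch.randn(intermediate, hidden)  # stand-in for a gate_proj checkpoint tensor
up_w = torch.randn(intermediate, hidden)    # stand-in for an up_proj checkpoint tensor
gate_up.weight.weight_loader(gate_up.weight, gate_w, 0)
gate_up.weight.weight_loader(gate_up.weight, up_w, 1)

assert torch.equal(gate_up.weight.data[:intermediate], gate_w)
assert torch.equal(gate_up.weight.data[intermediate:], up_w)
```

Binding through types.MethodType keeps the loader per-instance, so each fused Linear carries its own output_sizes and sharding logic without subclassing nn.Linear.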
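
The MLP forward still routes the fused gate_up output through SiluAndMul, which (assuming the usual vLLM-style semantics for that layer) splits the last dimension in half and computes silu(gate) * up; that is why gate_up_proj produces 2 * intermediate_size features. A rough plain-PyTorch equivalent of the new forward path, with illustrative sizes:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes, not taken from any particular Llama config.
hidden, intermediate = 16, 32
gate_up_proj = torch.nn.Linear(hidden, 2 * intermediate, bias=False)
down_proj = torch.nn.Linear(intermediate, hidden, bias=False)

x = torch.randn(4, hidden)
gate_up = gate_up_proj(x)            # [..., 2 * intermediate], gate rows then up rows
gate, up = gate_up.chunk(2, dim=-1)  # the split SiluAndMul is assumed to do internally
y = down_proj(F.silu(gate) * up)     # silu(gate) * up, then the down projection
assert y.shape == (4, hidden)
```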
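
Likewise, the fused qkv_proj lays the q, k, and v rows out contiguously, so the offsets in _get_shard_offset_mapping/_get_shard_size_mapping and the `qkv.split(...)` in LlamaAttention.forward have to agree. A small sketch of that layout arithmetic, using made-up head counts and no tensor parallelism or GQA replication:

```python
import torch

# Illustrative sizes, not from any particular Llama config.
num_heads, num_kv_heads, head_size, hidden = 8, 2, 16, 128

q_size = num_heads * head_size
kv_size = num_kv_heads * head_size

qkv_proj = torch.nn.Linear(hidden, q_size + 2 * kv_size, bias=False)

# Row offsets into the fused weight, mirroring _get_shard_offset_mapping.
offsets = {"q": 0, "k": q_size, "v": q_size + kv_size}
sizes = {"q": q_size, "k": kv_size, "v": kv_size}

# One matmul produces the concatenated projection, which forward then splits
# back into q, k, v along the last dimension.
x = torch.randn(4, hidden)
qkv = qkv_proj(x)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape[-1] == q_size and k.shape[-1] == kv_size and v.shape[-1] == kv_size

# The same offsets locate each shard's rows in the fused weight matrix.
for name in ("q", "k", "v"):
    shard = qkv_proj.weight.data.narrow(0, offsets[name], sizes[name])
    assert shard.shape == (sizes[name], hidden)
```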