
@jerryzh168
Created October 3, 2024 20:58
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index b63aaf1..9c268ab 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -18,6 +18,7 @@ limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Tuple
+from torch.nn.parameter import Parameter
import torch
from torch import nn
@@ -44,6 +45,33 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+import types
+
+def gate_up_proj_weight_loader(
+    self,
+    param: Parameter,
+    loaded_weight: torch.Tensor,
+    loaded_shard_id: Optional[int] = None,
+):
+    if loaded_shard_id is None:  # fused checkpoint tensor: split it and reload per shard
+        shard_offsets, current_shard_offset = [], 0  # (shard_id, offset, size) per shard
+        for i, output_size in enumerate(self.output_sizes):
+            shard_offsets.append((i, current_shard_offset, output_size))
+            current_shard_offset += output_size
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                0, shard_offset, shard_size  # dim 0 is the output dim of nn.Linear.weight
+            )
+            self.weight_loader(param, loaded_weight_shard, shard_id)
+    else:  # single gate_proj / up_proj tensor: copy it into its slice of the fused weight
+        assert loaded_shard_id < len(self.output_sizes)
+        param_data = param.data
+        shard_size = loaded_weight.shape[0]
+        shard_offset = loaded_shard_id * shard_size
+        param_data = param_data.narrow(0, shard_offset, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+    return
 class LlamaMLP(nn.Module):
@@ -56,20 +84,29 @@ class LlamaMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
+        # self.gate_up_proj = MergedColumnParallelLinear(
+        #     hidden_size,
+        #     [intermediate_size] * 2,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.gate_up_proj",
+        # )
+        self.gate_up_proj = torch.nn.Linear(
             hidden_size,
-            [intermediate_size] * 2,
+            intermediate_size * 2,
             bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.down_proj",
         )
+        self.gate_up_proj.output_sizes = [intermediate_size] * 2
+        self.gate_up_proj.weight_loader = types.MethodType(gate_up_proj_weight_loader, self.gate_up_proj)
+        self.gate_up_proj.weight.weight_loader = self.gate_up_proj.weight_loader
+        # self.down_proj = RowParallelLinear(
+        #     intermediate_size,
+        #     hidden_size,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.down_proj",
+        # )
+        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
         if hidden_act != "silu":
             raise ValueError(
                 f"Unsupported activation: {hidden_act}. "
@@ -78,12 +115,65 @@ class LlamaMLP(nn.Module):
         self.act_fn = SiluAndMul()
     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
+        gate_up = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x = self.down_proj(x)
         return x
+def _get_shard_offset_mapping(self, loaded_shard_id: str):
+    shard_offset_mapping = {
+        "q": 0,
+        "k": self.num_heads * self.head_size,
+        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
+        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+    }
+    return shard_offset_mapping.get(loaded_shard_id)
+
+def _get_shard_size_mapping(self, loaded_shard_id: str):
+    shard_size_mapping = {
+        "q": self.num_heads * self.head_size,
+        "k": self.num_kv_heads * self.head_size,
+        "v": self.num_kv_heads * self.head_size,
+    }
+    return shard_size_mapping.get(loaded_shard_id)
+
+def qkv_proj_weight_loader(
+    self,
+    param: Parameter,
+    loaded_weight: torch.Tensor,
+    loaded_shard_id: Optional[str] = None,
+):
+    if loaded_shard_id is None:  # fused checkpoint tensor: split it into q/k/v shards and reload each
+        shard_offsets = [
+            # (shard_id, shard_offset, shard_size)
+            ("q", 0, self.total_num_heads * self.head_size),
+            (
+                "k",
+                self.total_num_heads * self.head_size,
+                self.total_num_kv_heads * self.head_size,
+            ),
+            (
+                "v",
+                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
+                self.total_num_kv_heads * self.head_size,
+            ),
+        ]
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                param.output_dim, shard_offset, shard_size
+            )
+            self.weight_loader(param, loaded_weight_shard, shard_id)
+    else:  # single q/k/v tensor: copy it into its slice of the fused weight
+        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
+        shard_size = self._get_shard_size_mapping(loaded_shard_id)
+        param_data = param.data
+        param_data = param_data.narrow(0, shard_offset, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+    return
+
+
 class LlamaAttention(nn.Module):
     def __init__(
         self,
@@ -125,23 +215,42 @@ class LlamaAttention(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
-        self.qkv_proj = QKVParallelLinear(
+        # self.qkv_proj = QKVParallelLinear(
+        #     hidden_size,
+        #     self.head_dim,
+        #     self.total_num_heads,
+        #     self.total_num_kv_heads,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.qkv_proj",
+        # )
+        self.qkv_proj = torch.nn.Linear(
             hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            self.total_num_kv_heads,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
         )
-        self.o_proj = RowParallelLinear(
+        self.qkv_proj.total_num_heads = self.total_num_heads
+        self.qkv_proj.head_size = self.head_dim
+        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
+        self.qkv_proj.num_heads = self.total_num_heads
+        self.qkv_proj.num_kv_heads = self.total_num_kv_heads
+        self.qkv_proj.weight_loader = types.MethodType(qkv_proj_weight_loader, self.qkv_proj)
+        self.qkv_proj._get_shard_offset_mapping = types.MethodType(_get_shard_offset_mapping, self.qkv_proj)
+        self.qkv_proj._get_shard_size_mapping = types.MethodType(_get_shard_size_mapping, self.qkv_proj)
+        self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
+        self.qkv_proj.weight.output_dim = 0
+        # self.o_proj = RowParallelLinear(
+        #     self.total_num_heads * self.head_dim,
+        #     hidden_size,
+        #     bias=False,
+        #     quant_config=quant_config,
+        #     prefix=f"{prefix}.o_proj",
+        # )
+        self.o_proj = torch.nn.Linear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
         )
-
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -164,11 +273,11 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, input_metadata)
-        output, _ = self.o_proj(attn_output)
+        output = self.o_proj(attn_output)
         return output
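
Note: the patch works because the plain torch.nn.Linear modules are made compatible with the existing checkpoint-loading path by binding a custom weight_loader onto each module with types.MethodType; each per-shard checkpoint tensor is then copied into the right slice of the fused weight. Below is a minimal standalone sketch of that pattern, not part of the diff: only the per-shard branch of the gate_up loader is reproduced, and the sizes are made-up toy dimensions, not Llama's real ones.

import types

import torch


def gate_up_proj_weight_loader(self, param, loaded_weight, loaded_shard_id=None):
    # Simplified version of the patched loader above: shard 0 is gate_proj, shard 1 is up_proj.
    assert loaded_shard_id is not None and loaded_shard_id < len(self.output_sizes)
    shard_size = loaded_weight.shape[0]
    shard_offset = loaded_shard_id * shard_size
    # narrow() returns a view, so copy_() writes directly into the fused weight.
    param.data.narrow(0, shard_offset, shard_size).copy_(loaded_weight)


hidden_size, intermediate_size = 16, 32  # toy sizes for illustration
gate_up = torch.nn.Linear(hidden_size, intermediate_size * 2, bias=False)
gate_up.output_sizes = [intermediate_size] * 2
gate_up.weight_loader = types.MethodType(gate_up_proj_weight_loader, gate_up)

# Stand-ins for the gate_proj / up_proj tensors a HuggingFace checkpoint would provide.
gate_w = torch.randn(intermediate_size, hidden_size)
up_w = torch.randn(intermediate_size, hidden_size)
gate_up.weight_loader(gate_up.weight, gate_w, 0)  # gate_proj -> first half of the fused weight
gate_up.weight_loader(gate_up.weight, up_w, 1)    # up_proj   -> second half

assert torch.equal(gate_up.weight[:intermediate_size], gate_w)
assert torch.equal(gate_up.weight[intermediate_size:], up_w)

The qkv_proj loader in the diff applies the same trick, just with string shard ids ("q", "k", "v") and offsets computed from the head counts.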