Skip to content

Instantly share code, notes, and snippets.

@jerryzh168
Created November 2, 2024 00:19
Show Gist options
  • Save jerryzh168/bd65f122f24d5c92525f2504a1ff5870 to your computer and use it in GitHub Desktop.
Save jerryzh168/bd65f122f24d5c92525f2504a1ff5870 to your computer and use it in GitHub Desktop.
diff --git a/python/pyproject.toml b/python/pyproject.toml
index d9749e1..fbcc0fd 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"orjson", "packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "zmq",
"outlines>=0.0.44", "modelscope"]
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
+srt = ["sglang[runtime_common]", "torch", "vllm"]
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 94d48e8..f4125d1 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -38,6 +38,7 @@ from sglang.srt.utils import set_weight_attrs
logger = logging.getLogger(__name__)
[email protected]("silu_and_mul")
class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 3ae392e..a81b3cf 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -35,7 +35,7 @@ from vllm.model_executor.custom_op import CustomOp
logger = logging.getLogger(__name__)
-
[email protected]("rms_norm")
class RMSNorm(CustomOp):
def __init__(
self,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 583cbd9..8681878 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -241,6 +241,7 @@ class ModelRunner:
self.load_config = LoadConfig(load_format=self.server_args.load_format)
self.vllm_model_config = VllmModelConfig(
model=self.server_args.model_path,
+ task="draft",
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment