This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/python/pyproject.toml b/python/pyproject.toml | |
index d9749e1..fbcc0fd 100644 | |
--- a/python/pyproject.toml | |
+++ b/python/pyproject.toml | |
@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu | |
"orjson", "packaging", "pillow", "psutil", "pydantic", "python-multipart", | |
"torchao", "uvicorn", "uvloop", "zmq", | |
"outlines>=0.0.44", "modelscope"] | |
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"] | |
+srt = ["sglang[runtime_common]", "torch", "vllm"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from diffusers import FluxPipeline, FluxTransformer2DModel | |
import torch.utils.benchmark as benchmark | |
from functools import partial | |
def get_example_inputs(): | |
example_inputs = { | |
"hidden_states": torch.randn(1, 4096, 64, dtype=torch.bfloat16, device="cuda"), | |
"encoder_hidden_states": torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda"), | |
"pooled_projections": torch.randn(1, 768, dtype=torch.bfloat16, device="cuda"), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchvision import models | |
import torch | |
## compilation configs | |
torch._dynamo.config.automatic_dynamic_shapes = False | |
torch._inductor.config.force_fuse_int_mm_with_mul = True | |
torch._inductor.config.use_mixed_mm = True | |
## compilation configs end | |
# temporary workaround to recover the perf with quantized model under torch.compile | |
torch.backends.mha.set_fastpath_enabled(False) | |
import torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py | |
index b63aaf1..9c268ab 100644 | |
--- a/python/sglang/srt/models/llama.py | |
+++ b/python/sglang/srt/models/llama.py | |
@@ -18,6 +18,7 @@ limitations under the License. | |
"""Inference-only LLaMA model compatible with HuggingFace weights.""" | |
from typing import Any, Dict, Iterable, Optional, Tuple | |
+from torch.nn.parameter import Parameter | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
baseline (no tp) | |
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 | |
[15:07:14 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=79.41 GB | |
[15:07:14 TP0] Memory pool end. avail mem=11.16 GB | |
[15:07:14 TP0] Capture cuda graph begin. This can take up to several minutes. | |
max_total_num_tokens=557684 | |
Warmup ... | |
Prefill. latency: 0.03870 s, throughput: 3307.61 token/s |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[rank0]: run_once() | |
[rank0]: File "/data/users/jerryzh/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 199, in run_once | |
[rank0]: return forward(input_ids, input_metadata.positions, input_metadata) | |
[rank0]: File "/home/jerryzh/anaconda3/envs/sglang/lib/python3.10/site-packages/torch-2.4.0-py3.10-linux-x86_64.egg/torch/utils/_contextlib.py", line 116, in decorate_context | |
[rank0]: return func(*args, **kwargs) | |
[rank0]: File "/data/users/jerryzh/sglang/python/sglang/srt/models/llama.py", line 320, in forward | |
[rank0]: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) | |
[rank0]: File "/home/jerryzh/anaconda3/envs/sglang/lib/python3.10/site-packages/torch-2.4.0-py3.10-linux-x86_64.egg/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl | |
[rank0]: return self._call_impl(*args, **kwargs) | |
[rank0]: File "/home/jerryzh/anaconda3/envs/sglang/lib/python3.10/site-packages/torch-2.4.0-py3.10-linux-x86_64.egg/torch/nn/modules/module.py", line |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer | |
import torch.utils.benchmark as benchmark | |
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the mean as a string.

    The call is wrapped in a ``torch.utils.benchmark.Timer`` configured with
    the current torch thread count and measured via ``blocked_autorange()``.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        The ``mean`` of the resulting measurement, formatted with three
        decimal places.
    """
    # The statement string is evaluated against this namespace by the timer.
    namespace = {"args": args, "kwargs": kwargs, "f": f}
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals=namespace,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.blocked_autorange()
    return format(measurement.mean, ".3f")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
+ @common_utils.parametrize("device", COMMON_DEVICES) | |
+ @common_utils.parametrize("dtype", COMMON_DTYPES) | |
+ def test_linear_compile(self, device, dtype): | |
+ hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) | |
+ lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) | |
+ | |
+ hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) | |
+ hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) | |
+ l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) | |
+ l.weight = torch.nn.Parameter(lp_tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
...........frames [('total', 1), ('ok', 1)] | |
inductor [('pattern_matcher_count', 4), ('pattern_matcher_nodes', 4), ('fxgraph_cache_miss', 1), ('extern_calls', 1)] | |
inline_call [] | |
stats [('calls_captured', 1), ('unique_graphs', 1)] | |
aot_autograd [('total', 1), ('ok', 1)] | |
.frames [('total', 1), ('ok', 1)] | |
inductor [('pattern_matcher_count', 4), ('pattern_matcher_nodes', 4), ('fxgraph_cache_miss', 1), ('extern_calls', 1)] | |
inline_call [] | |
stats [('calls_captured', 1), ('unique_graphs', 1)] | |
aot_autograd [('total', 1), ('ok', 1)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from diffusers import FluxTransformer2DModel | |
from torchao.quantization import quantize_, int8_weight_only | |
import torch | |
from torchao import autoquant | |
ckpt_id = "black-forest-labs/FLUX.1-schnell" | |
transformer = FluxTransformer2DModel.from_pretrained( | |
ckpt_id, subfolder="transformer", torch_dtype=torch.bfloat16 | |
) |
Newer | Older