vllm_benchmark
in_len=1024; out_len=1024; batch_size=128;
model=meta-llama/Llama-3.1-8B-Instruct
VLLM_USE_V1=1 VLLM_DISABLE_COMPILE_CACHE=1 TRITON_PRINT_AUTOTUNING=1 python3 vllm_benchmark.py --input-len $in_len --output-len $out_len --model $model --dtype float16 --batch-size $batch_size --num_iters_warmup 5 --num_iters 5
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import os
import time
from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser
##########################################################################################
# pip install git+https://github.com/mobiusml/gemlite --upgrade;
# git clone https://github.com/vllm-project/vllm.git; cd vllm; VLLM_USE_PRECOMPILED=1 pip install -e .; cd ~;
# wget https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/benchmarks/benchmark_utils.py;

# import logging
# import torch._dynamo as dynamo
# dynamo.config.log_level = logging.INFO
# dynamo.config.verbose = True

import gemlite
# gemlite.set_packing_bitwidth(32)
# cache_name = 'test_config.json'
# gemlite.reset_config()
# gemlite.load_config(cache_name)
# gemlite.set_autotune("max")  # max / fast
# gemlite.set_autotune(dict([(m, "max") for m in ['GEMV_SPLITK', 'GEMV_REVSPLITK', 'GEMV']]))
# gemlite.set_autotune(dict([(m, "max") for m in ['GEMM_SPLITK']]))
# gemlite.set_kernel_caching(True)
# gemlite.core.GEMLITE_TRITON_CONFIG_CACHE['GEMV_REVSPLITK'] = {}
# gemlite.core.GEMLITE_TRITON_CONFIG_CACHE['GEMM_SPLITK'] = {}
# gemlite.core.GEMLITE_TRITON_CONFIG_CACHE['GEMM'] = {}
# gemlite.core.GEMLITE_TRITON_CONFIG_CACHE['GEMM_SPLITK', 'GEMM'] = {}
# gemlite.set_autotune(dict([(m, "max") for m in ['GEMM_SPLITK', 'GEMM']]))
# gemlite.set_packing_bitwidth(32)
# gemlite.core.get_default_gemv = lambda W_nbits: 'GEMV_REVSPLITK' if (W_nbits < 8) else 'GEMV_SPLITK'
# gemlite.set_packing_bitwidth(8)
# gemlite.core.get_default_gemv = lambda W_nbits: 'GEMV' if (W_nbits < 8) else 'GEMV_SPLITK'

from hqq.utils.vllm import set_vllm_onthefly_hqq_quant
skip_modules = ['lm_head', 'visual', 'vision']
# set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='int8_weightonly', skip_modules=skip_modules)  # A16W8 - INT8 weight-only
# set_vllm_onthefly_hqq_quant(weight_bits=4, group_size=128, quant_mode='int4_weightonly', skip_modules=skip_modules)  # A16W4 - HQQ weight-only
# set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='int8_dynamic', skip_modules=skip_modules)  # A8W8 - INT8 x INT8 dynamic
# set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='fp8_dynamic', skip_modules=skip_modules)  # A8W8 - FP8 x FP8 dynamic
# set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='mxfp8_dynamic', skip_modules=skip_modules)  # A8W8 - MXFP8 x MXFP8 - post_scale=True
# set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=32, quant_mode='mxfp8_dynamic', skip_modules=skip_modules)  # A8W8 - MXFP8 x MXFP8 - post_scale=False
# set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp4_weightonly', skip_modules=skip_modules)  # A16W4 - MXFP4 weight-only
# set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp8_dynamic', skip_modules=skip_modules)  # A8W4 - MXFP8 x MXFP4 dynamic - OK
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp4_dynamic', skip_modules=skip_modules)  # A4W4 - MXFP4 x MXFP4 dynamic
# set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='nvfp4_dynamic', skip_modules=skip_modules)  # A4W4 - NVFP4 x NVFP4 dynamic
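# Sketch (added, not in the original gist): the active line above pins one quant mode; to switch
# modes without editing the file, something like the commented snippet below should work. The env
# var name HQQ_QUANT_MODE is made up for illustration; the mode strings are the ones listed above.
#   quant_mode = os.environ.get("HQQ_QUANT_MODE", "mxfp4_dynamic")
#   set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode=quant_mode, skip_modules=skip_modules)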
max_model_len = 4096  # 4096, 16384
compilation_config = None

# # Patching
# target_size = 16
# # gemlite.core.get_matmul_type = lambda batch_size, W_nbits: "GEMM_SPLITK"
# gemlite.triton_kernels.utils.get_closest_m = lambda M: max_model_len if M > target_size else target_size
# compilation_config = {"cudagraph_num_of_warmups": 1, "cudagraph_capture_sizes": [max_model_len, target_size], "compile_sizes": [max_model_len, target_size], "max_capture_size": max_model_len}
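# Added note, my reading of the commented patch above: un-commenting it makes gemlite bucket the
# batch dimension M to either target_size or max_model_len, and restricts vLLM's CUDA-graph
# capture / compile sizes to those same two shapes, so decode and prefill each hit a single
# pre-tuned kernel shape. Treat this as an interpretation, not documented gemlite/vLLM behavior.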
# in_len=256; out_len=1024; batch_size=1;
# model=meta-llama/Llama-3.1-8B-Instruct
# model=microsoft/Phi-4-mini-instruct
# model=Qwen/Qwen2.5-7B-Instruct
# model=Qwen/Qwen2.5-3B-Instruct
# model=meta-llama/Llama-3.2-3B
# model=Qwen/Qwen2.5-14B-Instruct
# model=hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4
# model=mobiuslabsgmbh/Meta-Llama-3-8B-Instruct_4bitgs64_hqq_hf
# VLLM_USE_V1=1 VLLM_DISABLE_COMPILE_CACHE=1 TRITON_PRINT_AUTOTUNING=1 python3 vllm_benchmark.py --input-len $in_len --output-len $out_len --model $model --dtype float16 --batch-size $batch_size --num_iters_warmup 5 --num_iters 5
##########################################################################################


def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"latency": results["latencies"]},
        extra_info={k: results[k]
                    for k in ["avg_latency", "percentiles"]})
    if pt_records:
        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)
def main(args: argparse.Namespace):
    print(args)
    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    params = dataclasses.asdict(engine_args)
    params['gpu_memory_utilization'] = 0.9
    params['max_model_len'] = max_model_len  # 4096
    if compilation_config is not None:
        params['compilation_config'] = compilation_config
    llm = LLM(**params)

    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len +
        args.output_len), ("Please ensure that max_model_len is greater than"
                           " the sum of input_len and output_len.")

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
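    # Note (added): ignore_eos=True with max_tokens=args.output_len forces every request to
    # generate exactly output_len tokens, so each iteration measures a fixed amount of work
    # regardless of when the model would naturally emit EOS.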
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: list[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def llm_generate():
        if not args.use_beam_search:
            llm.generate(dummy_prompts,
                         sampling_params=sampling_params,
                         use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir)),
            ) as p:
                llm_generate()
            print(p.key_averages().table(sort_by="self_cuda_time_total"))
        else:
            start_time = time.perf_counter()
            llm_generate()
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = (Path(".") / "vllm_benchmark_result" /
                           f"latency_result_{time.time()}")
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f"Avg latency: {np.mean(latencies)} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the latency of processing a single batch of "
        "requests till completion.")
    parser.add_argument("--input-len", type=int, default=32)
    parser.add_argument("--output-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument(
        "--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument("--num-iters",
                        type=int,
                        default=30,
                        help="Number of iterations to run.")
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--profile-result-dir",
        type=str,
        default=None,
        help=("path to save the pytorch profiler output. Can be visualized "
              "with ui.perfetto.dev or Tensorboard."),
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=("Do not detokenize responses (i.e. do not include "
              "detokenization time in the latency measurement)"),
    )
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)
    # gemlite.cache_config(cache_name)
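# Post-processing sketch (added, not part of the original benchmark): the JSON written via
# --output-json has the shape {"avg_latency": ..., "latencies": [...], "percentiles": {...}},
# so a saved run can be inspected with something like the following ("results.json" is just an
# example path, whatever was passed to --output-json):
#   import json
#   with open("results.json") as f:
#       results = json.load(f)
#   print(f"avg latency: {results['avg_latency']:.3f} s")
#   for pct, value in results["percentiles"].items():
#       print(f"p{pct}: {value:.3f} s")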