"""Benchmark offline inference throughput.""" | |
import argparse | |
import json | |
import os | |
import random | |
import time | |
from typing import List, Optional, Tuple | |
import modal | |
import torch | |
from tqdm import tqdm | |
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase | |
MODEL_DIR = "/model" | |
MODEL_NAME = "meta-llama/Meta-Llama-3-8B" | |
app = modal.App( | |
"benchmark_vllm_throughput", | |
) | |
def download_model_to_image(model_dir, model_name): | |
from huggingface_hub import snapshot_download | |
from transformers.utils import move_cache | |
os.makedirs(model_dir, exist_ok=True) | |
snapshot_download( | |
model_name, | |
local_dir=model_dir, | |
token=os.environ["HF_TOKEN"], | |
ignore_patterns=["*.pt", "*.gguf"], # Using safetensors | |
) | |
move_cache() | |
image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.2.2-devel-ubuntu22.04",
        add_python="3.10",
    )
    .apt_install("git")
    .env({"TORCH_CUDA_ARCH_LIST": "8.0 8.6 8.9 9.0"})
    .pip_install(
        "git+https://github.com/vllm-project/vllm.git@468d761b32e3b3c5d64eeaa797e54ab809b7e50c",
        "torch==2.2.1",
        "transformers==4.40.1",
        "ray==2.10.0",
        "huggingface_hub==0.19.4",
        "hf-transfer==0.1.4",
        "packaging",
        "wheel",
    )
    .run_commands(
        "pip install flash-attn==2.5.7 --no-build-isolation",
    )
    # Use the barebones hf-transfer package for maximum download speeds. Speeds vary
    # from 100 MB/s to 1.5 GB/s, so download times range from under a minute to tens
    # of minutes. If your download slows down or times out, try interrupting and
    # restarting.
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(
        download_model_to_image,
        secrets=[modal.Secret.from_name("hf_read_token")],
        timeout=60 * 20,
        kwargs={"model_dir": MODEL_DIR, "model_name": MODEL_NAME},
    )
)

app = modal.App(f"example-vllm-{MODEL_NAME}", image=image)
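# Note: the image build above bakes the gated Llama 3 weights into the image, so it
# assumes a Modal secret named "hf_read_token" exists in your workspace and exposes a
# valid Hugging Face read token as HF_TOKEN.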
# with image.imports():
#     from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
#     import vllm


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset
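# Note on dataset format: sample_requests() expects a ShareGPT-style JSON file, i.e. a
# list of records, each with a "conversations" list whose turns hold their text under
# a "value" key. A minimal, hypothetical example of one record:
#
#   {"conversations": [{"value": "What is the capital of France?"},
#                      {"value": "The capital of France is Paris."}]}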
def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    from vllm import LLM, SamplingParams

    print(f"Using device {device}")
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.perf_counter()
    return end - start
def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (
                max(max_prompt_len, next_prompt_len)
                + max(max_output_len, next_output_len)
            ) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start
def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    from mii import client, serve

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    llm.generate(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    client = client(model)
    client.terminate_server()
    return end - start
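# Note: the MII backend imports `mii`, which is not installed in the image defined
# above. To benchmark with backend="mii", add the DeepSpeed-MII package (published on
# PyPI as "deepspeed-mii") to the image's pip_install list.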
GPU_CONFIG = modal.gpu.H100(count=1)


@app.function(gpu=GPU_CONFIG, secrets=[modal.Secret.from_name("hf_read_token")])
def benchmark_throughput(
    backend="vllm",
    dataset=None,
    input_len=None,
    output_len=None,
    model=MODEL_NAME,
    tokenizer=None,
    quantization=None,
    tensor_parallel_size=1,
    n=1,
    use_beam_search=False,
    num_prompts=1000,
    seed=0,
    hf_max_batch_size=None,
    trust_remote_code=False,
    max_model_len=None,
    dtype="auto",
    gpu_memory_utilization=0.9,
    enforce_eager=False,
    kv_cache_dtype="auto",
    quantization_param_path=None,
    device="cuda",
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
    max_num_batched_tokens=None,
    download_dir=None,
):
print(f"Running benchmark with backend {backend}") | |
random.seed(seed) | |
if tokenizer is None: | |
tokenizer = model # Default tokenizer is the same as the model | |
# Set up tokenizer | |
# tokenizer = AutoTokenizer.from_pretrained( | |
# tokenizer, trust_remote_code=trust_remote_code | |
# ) | |
    if dataset is None:
        assert (
            input_len is not None and output_len is not None
        ), "input_len and output_len must be specified if no dataset is provided"
        # Synthesize prompts if no dataset is provided.
        prompt = "hi" * (input_len - 1)
        requests = [(prompt, input_len, output_len) for _ in range(num_prompts)]
    else:
        requests = sample_requests(dataset, num_prompts, hf_tokenizer, output_len)
if backend == "vllm": | |
elapsed_time = run_vllm( | |
requests, | |
model, | |
tokenizer, | |
quantization, | |
tensor_parallel_size, | |
seed, | |
n, | |
use_beam_search, | |
trust_remote_code, | |
dtype, | |
max_model_len, | |
enforce_eager, | |
kv_cache_dtype, | |
quantization_param_path, | |
device, | |
enable_prefix_caching, | |
enable_chunked_prefill, | |
max_num_batched_tokens, | |
gpu_memory_utilization, | |
download_dir, | |
) | |
elif backend == "hf": | |
assert ( | |
tensor_parallel_size == 1 | |
), "tensor_parallel_size must be 1 for HF backend" | |
if hf_max_batch_size is None: | |
raise ValueError("HF max batch size is required for HF backend.") | |
elapsed_time = run_hf( | |
requests, | |
model, | |
tokenizer, | |
n, | |
use_beam_search, | |
hf_max_batch_size, | |
trust_remote_code, | |
) | |
elif backend == "mii": | |
elapsed_time = run_mii(requests, model, tensor_parallel_size, output_len) | |
else: | |
raise ValueError(f"Unknown backend: {backend}") | |
total_num_tokens = sum( | |
prompt_len + output_len for _, prompt_len, output_len in requests | |
) | |
print( | |
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " | |
f"{total_num_tokens / elapsed_time:.2f} tokens/s" | |
) | |
@app.local_entrypoint()
def main():
    benchmark_throughput.remote(
        input_len=1024,
        output_len=256,
    )
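# Usage sketch (assumes this file is saved locally, e.g. as benchmark_throughput.py,
# and that the "hf_read_token" Modal secret described above is configured):
#
#   modal run benchmark_throughput.py
#
# With the defaults above, this benchmarks the vLLM backend on 1000 synthetic prompts,
# each counted as 1024 input tokens and 256 output tokens. To try the HF backend
# instead, adjust the remote call in main(), e.g.:
#
#   benchmark_throughput.remote(
#       backend="hf", input_len=1024, output_len=256, hf_max_batch_size=32
#   )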