"""Benchmark offline inference throughput."""
import argparse
import json
import os
import random
import time
from typing import List, Optional, Tuple
import modal
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
MODEL_DIR = "/model"
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
app = modal.App(
"benchmark_vllm_throughput",
)


def download_model_to_image(model_dir, model_name):
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    os.makedirs(model_dir, exist_ok=True)
    snapshot_download(
        model_name,
        local_dir=model_dir,
        token=os.environ["HF_TOKEN"],
        ignore_patterns=["*.pt", "*.gguf"],  # Using safetensors
    )
    move_cache()


image = (
    modal.Image.from_registry(
        "nvidia/cuda:12.2.2-devel-ubuntu22.04",
        add_python="3.10",
    )
    .apt_install("git")
    .env({"TORCH_CUDA_ARCH_LIST": "8.0 8.6 8.9 9.0"})
    .pip_install(
        "git+https://github.com/vllm-project/vllm.git@468d761b32e3b3c5d64eeaa797e54ab809b7e50c",
        "torch==2.2.1",
        "transformers==4.40.1",
        "ray==2.10.0",
        "huggingface_hub==0.19.4",
        "hf-transfer==0.1.4",
        "packaging",
        "wheel",
    )
    .run_commands(
        "pip install flash-attn==2.5.7 --no-build-isolation",
    )
    # Use the barebones hf-transfer package for maximum download speeds. Speeds vary
    # from 100 MB/s to 1.5 GB/s, so download times can range from under a minute to
    # tens of minutes. If your download slows down or times out, try interrupting
    # and restarting.
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(
        download_model_to_image,
        secrets=[modal.Secret.from_name("hf_read_token")],
        timeout=60 * 20,
        kwargs={"model_dir": MODEL_DIR, "model_name": MODEL_NAME},
    )
)
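
# The `hf_read_token` Modal secret referenced above is assumed to expose an
# HF_TOKEN environment variable with read access to the gated Llama 3 repo,
# since download_model_to_image reads os.environ["HF_TOKEN"]. One way to create
# it from the Modal CLI (the token value below is a placeholder):
#
#   modal secret create hf_read_token HF_TOKEN=hf_xxx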

app = modal.App(f"example-vllm-{MODEL_NAME}", image=image)

# with image.imports():
#     from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
#     import vllm


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out conversations with fewer than two turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset
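
# A minimal usage sketch for sample_requests, assuming a ShareGPT-style JSON file
# in which each record has a "conversations" list of turns with a "value" field.
# The file name and variable names below are placeholders, not part of the gist:
#
#   tok = AutoTokenizer.from_pretrained(MODEL_NAME)
#   reqs = sample_requests("sharegpt.json", num_requests=1000, tokenizer=tok,
#                          fixed_output_len=None)
#   # Each entry is (prompt, prompt_len, output_len), filtered so that
#   # prompt_len <= 1024 and prompt_len + output_len <= 2048.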


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    from vllm import LLM, SamplingParams

    print(f"Using device {device}")
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.perf_counter()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (
                max(max_prompt_len, next_prompt_len)
                + max(max_output_len, next_output_len)
            ) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start


def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    from mii import client, serve

    llm = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]

    start = time.perf_counter()
    llm.generate(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    client = client(model)
    client.terminate_server()
    return end - start


GPU_CONFIG = modal.gpu.H100(count=1)
@app.function(gpu=GPU_CONFIG, secrets=[modal.Secret.from_name("hf_read_token")])
def benchmark_throughput(
    backend="vllm",
    dataset=None,
    input_len=None,
    output_len=None,
    model=MODEL_NAME,
    tokenizer=None,
    quantization=None,
    tensor_parallel_size=1,
    n=1,
    use_beam_search=False,
    num_prompts=1000,
    seed=0,
    hf_max_batch_size=None,
    trust_remote_code=False,
    max_model_len=None,
    dtype="auto",
    gpu_memory_utilization=0.9,
    enforce_eager=False,
    kv_cache_dtype="auto",
    quantization_param_path=None,
    device="cuda",
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
    max_num_batched_tokens=None,
    download_dir=None,
):
print(f"Running benchmark with backend {backend}")
random.seed(seed)
if tokenizer is None:
tokenizer = model # Default tokenizer is the same as the model
# Set up tokenizer
# tokenizer = AutoTokenizer.from_pretrained(
# tokenizer, trust_remote_code=trust_remote_code
# )
if dataset is None:
assert (
input_len is not None and output_len is not None
), "input_len and output_len must be specified if no dataset is provided"
# Synthesize prompts if no dataset is provided
prompt = "hi" * (input_len - 1)
requests = [(prompt, input_len, output_len) for _ in range(num_prompts)]
else:
requests = sample_requests(dataset, num_prompts, tokenizer, output_len)
if backend == "vllm":
elapsed_time = run_vllm(
requests,
model,
tokenizer,
quantization,
tensor_parallel_size,
seed,
n,
use_beam_search,
trust_remote_code,
dtype,
max_model_len,
enforce_eager,
kv_cache_dtype,
quantization_param_path,
device,
enable_prefix_caching,
enable_chunked_prefill,
max_num_batched_tokens,
gpu_memory_utilization,
download_dir,
)
elif backend == "hf":
assert (
tensor_parallel_size == 1
), "tensor_parallel_size must be 1 for HF backend"
if hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
elapsed_time = run_hf(
requests,
model,
tokenizer,
n,
use_beam_search,
hf_max_batch_size,
trust_remote_code,
)
elif backend == "mii":
elapsed_time = run_mii(requests, model, tensor_parallel_size, output_len)
else:
raise ValueError(f"Unknown backend: {backend}")

    total_num_tokens = sum(
        prompt_len + output_len for _, prompt_len, output_len in requests
    )
    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} tokens/s"
    )


@app.local_entrypoint()
def main():
    benchmark_throughput.remote(
        input_len=1024,
        output_len=256,
    )
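

# Usage sketch (the file name is simply whatever this script was saved as):
#
#   modal run benchmark_vllm_throughput.py
#
# `modal run` executes main() locally, which calls benchmark_throughput.remote()
# on a single H100 in Modal's cloud using the default vLLM backend with synthetic
# 1024-token prompts and 256-token outputs, and prints requests/s and tokens/s.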