"""Benchmark offline inference throughput."""
import argparse
import base64
import dataclasses
import io
import math
import random
import time
from itertools import chain
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.utils import FlexibleArgumentParser


def sample_mmmu_pro_vision_requests(
    dataset,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    image_hit_rate: float,
):
    sampled_requests = []
    num_unique_images = int(num_requests * (1 - image_hit_rate))
    print(f"Total {num_requests} requests with {num_unique_images} unique images")
    dataset = dataset.take(num_unique_images)
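
    # With hit rate h, only num_requests * (1 - h) distinct images are drawn;
    # the remaining requests repeat them (e.g. 1000 requests at h=0.9 cycle
    # through 100 images), which presumably exercises any image/prefix caching
    # in the engine.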
    for data in dataset:
        # MMMU-Pro vision direct prompt
        # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
        prompt = (
            "Answer with the option letter from the given choices directly. "
            "The last line of your response should be of the following "
            "format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
            "options.")
        # Re-encode the PIL image as a base64 JPEG for embedding in a data URL.
        image: Image.Image = data["image"]
        image = image.convert("RGB")
        image_data = io.BytesIO()
        image.save(image_data, format="JPEG")
        image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        mm_content = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
        }
        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                mm_content,
            ],
        }]
        sampled_requests.append(messages)
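
    # Repeat each unique request back-to-back until num_requests is reached;
    # the copies carry identical image payloads, which is what produces hits.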
    n = math.ceil(num_requests / num_unique_images)
    sampled_requests = list(
        chain.from_iterable([x] * n for x in sampled_requests))[:num_requests]
    return sampled_requests


def sample_hf_requests(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    random_seed: int,
    image_hit_rate: float,
):
    # Stream the dataset so only the sampled rows are downloaded.
    dataset = load_dataset("MMMU/MMMU_Pro",
                           name="vision",
                           split="test",
                           streaming=True)
    dataset = dataset.shuffle(seed=random_seed)
    return sample_mmmu_pro_vision_requests(dataset, num_requests, tokenizer,
                                           image_hit_rate)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    engine_args = EngineArgs.from_cli_args(args)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    sampled = sample_hf_requests(args.num_prompts, tokenizer, args.seed,
                                 args.image_hit_rate)

    llm = LLM(**dataclasses.asdict(engine_args))
    # Near-greedy decoding, capped at --output-len tokens per request.
    sampling_params = SamplingParams(max_tokens=args.output_len,
                                     temperature=0.01)

    # Time the entire llm.chat() call, so the reported throughput includes
    # tokenization and image preprocessing, not just token generation.
    st = time.perf_counter()
    llm.chat(sampled, sampling_params=sampling_params)
    print(
        f"Throughput: {args.num_prompts / (time.perf_counter() - st):.2f} req/s"
    )
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--output-len",
type=int,
default=128,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--image-hit-rate",
type=float,
default=0.0,
help="Image hit rate between 0 and 1.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
main(args)
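
# Example invocation (script and model names are illustrative; any
# multimodal model supported by vLLM should work):
#
#   python benchmark_mmmu_throughput.py \
#       --model Qwen/Qwen2-VL-7B-Instruct \
#       --num-prompts 1000 \
#       --image-hit-rate 0.5
#
# All engine flags (--model, --tensor-parallel-size, ...) come from
# AsyncEngineArgs.add_cli_args above.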