"""Benchmark offline inference throughput.""" | |
import argparse
import base64
import dataclasses
import io
import math
import random
import time
from itertools import chain

from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.utils import FlexibleArgumentParser


def sample_mmmu_pro_vision_requests(
    dataset,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    image_hit_rate: float,
):
    sampled_requests = []
    # A hit rate of x means a fraction x of the requests reuse a previously
    # seen image, so only (1 - x) * num_requests unique images are drawn.
    # Keep at least one unique image so a hit rate of 1.0 does not divide
    # by zero below.
    num_unique_images = max(int(num_requests * (1 - image_hit_rate)), 1)
    print(f"Total {num_requests} requests with {num_unique_images} unique images")
    dataset = dataset.take(num_unique_images)
    for data in dataset:
        # MMMU-Pro vision direct prompt
        # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
        prompt = (
            "Answer with the option letter from the given choices directly. "
            "The last line of your response should be of the following "
            "format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
            "options.")

        # Encode the image as a base64 JPEG data URL, matching the
        # OpenAI-style multimodal message format that llm.chat() accepts.
        image: Image.Image = data["image"]
        image = image.convert("RGB")
        image_data = io.BytesIO()
        image.save(image_data, format="JPEG")
        image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
        mm_content = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{image_base64}"
            },
        }
        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                mm_content,
            ],
        }]
        sampled_requests.append(messages)

    # Repeat each unique request n times, then truncate to num_requests.
    # E.g., num_requests=10 with image_hit_rate=0.8 gives num_unique_images=2
    # and n=5, so each unique request appears 5 times.
    n = math.ceil(num_requests / num_unique_images)
    sampled_requests = list(
        chain.from_iterable([x] * n for x in sampled_requests))[:num_requests]
    return sampled_requests


def sample_hf_requests(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    random_seed: int,
    image_hit_rate: float,
):
    # Stream the dataset so only the sampled images are downloaded.
    dataset = load_dataset("MMMU/MMMU_Pro",
                           name="vision",
                           split="test",
                           streaming=True)
    dataset = dataset.shuffle(seed=random_seed)
    return sample_mmmu_pro_vision_requests(dataset, num_requests, tokenizer,
                                           image_hit_rate)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    engine_args = EngineArgs.from_cli_args(args)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    sampled = sample_hf_requests(args.num_prompts, tokenizer, args.seed,
                                 args.image_hit_rate)

    llm = LLM(**dataclasses.asdict(engine_args))
    sampling_params = SamplingParams(max_tokens=args.output_len,
                                     temperature=0.01)

    st = time.perf_counter()
    llm.chat(sampled, sampling_params=sampling_params)
    print(
        f"Throughput: {args.num_prompts / (time.perf_counter() - st):.2f} req/s"
    )


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--output-len",
                        type=int,
                        default=128,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--image-hit-rate",
                        type=float,
                        default=0.0,
                        help="Image hit rate between 0 and 1.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    main(args)