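"""Benchmark CLIP attention implementations (eager vs. SDPA vs. FlashAttention-2).

Loads a CLIP checkpoint (openai/clip-vit-large-patch14 by default) once per
available attention implementation, times forward passes over a grid of image
and text batch sizes, and prints markdown tables with the median seconds per
iteration, a relative confidence interval, and the speedup over eager attention.
"""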
import sys
import time
import requests
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import torch
import transformers
print("\n## Environment:\n")
print("Python version:", sys.version)
print("Transformers version:", transformers.__version__)
print("Torch version:", torch.__version__)
if torch.cuda.is_available():
print("GPU:", torch.cuda.get_device_name(0))

@torch.no_grad()
def get_model_iteration_time(model, inputs, device, min_iterations=100, min_benchmark_time=4, warm_up_steps=10):
    with torch.autocast(device):

        # warm-up: run a few untimed forward passes so kernel selection and caches settle
        for _ in range(warm_up_steps):
            model(**inputs)

        timings = []
        iterations = 0
        benchmark_time = 0

        torch.cuda.synchronize()
        while benchmark_time < min_benchmark_time or iterations < min_iterations:
            for _ in range(10):
                start_time = time.time()
                _ = model(**inputs)
                torch.cuda.synchronize()
                end_time = time.time()
                elapsed_time = end_time - start_time

                # store the time
                timings.append(elapsed_time)

                # update the benchmark time and iterations
                benchmark_time += elapsed_time
                iterations += 1

    median_time = np.median(timings)
    # 95% confidence interval of the mean timing under a normal approximation
    ci = 1.96 * np.array(timings).std() / np.sqrt(len(timings))
    return median_time, ci

def prepare_inputs(processor, image_batch_size=None, text_batch_size=None):
    # loading image
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    images = [image] * image_batch_size if image_batch_size is not None else None
    texts = ["a photo of 2 cats"] * text_batch_size if text_batch_size is not None else None

    inputs = processor(text=texts, images=images, padding="max_length", return_tensors="pt")
    return inputs

def format_df(df) -> str:
    # format float numbers
    for column in df.columns:
        if "CI" in column and "%" in column:
            df[column] = df[column].apply(lambda x: f"±{x:.1f}%")
        else:
            df[column] = df[column].apply(lambda x: f"{x:.3f}")

    # rename columns
    columns_mapping = {
        "image_batch_size": "Image batch size",
        "text_batch_size": "Num text labels",
        "Eager": "Eager (s/iter)",
        "FA2": "FA2 (s/iter)",
        "SDPA": "SDPA (s/iter)",
    }
    for column_name, new_column_name in columns_mapping.items():
        if column_name in df.columns:
            df = df.rename(columns={column_name: new_column_name})

    # format as markdown table
    markdown = df.to_markdown(index=False)
    return markdown

def benchmark(models_dict, processor, device, image_batch_sizes=None, text_batch_sizes=None, n_iterations=100):
    image_batch_sizes = image_batch_sizes or [None]
    text_batch_sizes = text_batch_sizes or [None]
    cases = list(itertools.product(image_batch_sizes, text_batch_sizes))

    results = []
    for image_batch_size, text_batch_size in tqdm(cases):
        inputs = prepare_inputs(processor, image_batch_size, text_batch_size).to(device)

        step_results = {}
        if image_batch_size is not None:
            step_results["image_batch_size"] = image_batch_size
        if text_batch_size is not None:
            step_results["text_batch_size"] = text_batch_size

        for attn_name, model in models_dict.items():
            median_time, confidence_interval = get_model_iteration_time(
                model, inputs, device, min_iterations=n_iterations, min_benchmark_time=4
            )
            step_results[f"{attn_name}"] = median_time
            confidence_interval_percent = (confidence_interval / median_time) * 100
            step_results[f"{attn_name} CI, %"] = confidence_interval_percent

            # "Eager" is always benchmarked first, so its timing is available for the speedup ratio
            if attn_name != "Eager":
                step_results[f"{attn_name} speedup"] = step_results["Eager"] / median_time

        results.append(step_results)

    df = pd.DataFrame(results)
    markdown = format_df(df)
    return markdown

def load_models(model_class, checkpoint, dtype, device):
    models_dict = {
        "Eager": model_class.from_pretrained(checkpoint, attn_implementation="eager", torch_dtype=dtype, device_map=device).eval()
    }
    if model_class._supports_flash_attn_2:
        models_dict["FA2"] = model_class.from_pretrained(checkpoint, attn_implementation="flash_attention_2", torch_dtype=dtype, device_map=device).eval()
    if model_class._supports_sdpa:
        models_dict["SDPA"] = model_class.from_pretrained(checkpoint, attn_implementation="sdpa", torch_dtype=dtype, device_map=device).eval()
    return models_dict

if __name__ == "__main__":

    import argparse
    from transformers import AutoProcessor, CLIPModel, CLIPTextModel, CLIPVisionModel

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_iterations", type=int, default=100)
    parser.add_argument("--checkpoint", type=str, default="openai/clip-vit-large-patch14")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()

    benchmark_multimodal = True
    benchmark_text = True
    benchmark_vision = True

    dtype = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }[args.dtype]

    processor = AutoProcessor.from_pretrained(args.checkpoint)

    print("\n## Benchmark results\n")

    # ---------------------------
    # Multi-modal model
    # ---------------------------

    if benchmark_multimodal:
        models_dict = load_models(CLIPModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            image_batch_sizes=[1, 4, 16, 32],
            text_batch_sizes=[4, 16, 32, 64],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Text model
    # ---------------------------

    if benchmark_text:
        models_dict = load_models(CLIPTextModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            text_batch_sizes=[4, 16, 32, 64, 128],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPTextModel.__name__}\n")
        print(result)
        print()

    # ---------------------------
    # Vision model
    # ---------------------------

    if benchmark_vision:
        models_dict = load_models(CLIPVisionModel, args.checkpoint, dtype, args.device)
        result = benchmark(
            models_dict,
            processor,
            image_batch_sizes=[1, 4, 16, 32],
            device=args.device,
            n_iterations=args.n_iterations,
        )
        print(f"\n### {CLIPVisionModel.__name__}\n")
        print(result)
        print()
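    # Example invocation (a sketch, not part of the original gist; assumes the
    # script is saved as benchmark_clip_attention.py, a CUDA GPU is available,
    # and flash-attn is installed so the FA2 variant can be loaded):
    #
    #   python benchmark_clip_attention.py --checkpoint openai/clip-vit-large-patch14 \
    #       --device cuda --dtype float16 --n_iterations 100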