Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Created August 16, 2023 14:11
Show Gist options
  • Save sayakpaul/0211e1f2742159cc2176b326dff435f1 to your computer and use it in GitHub Desktop.
Save sayakpaul/0211e1f2742159cc2176b326dff435f1 to your computer and use it in GitHub Desktop.
"""
Examples:
(1) python benchmark_controlnet_sdxl.py --controlnet_id diffusers/controlnet-depth-sdxl-1.0
(2) python benchmark_controlnet_sdxl.py --controlnet_id diffusers/controlnet-depth-sdxl-1.0-small
(3) python benchmark_controlnet_sdxl.py --controlnet_id diffusers/controlnet-depth-sdxl-1.0-mid
"""
import argparse
import time
import torch
from diffusers import (AutoencoderKL, ControlNetModel,
StableDiffusionXLControlNetPipeline)
from diffusers.utils import load_image
# Base SDXL checkpoint the ControlNet is attached to.
PIPELINE_ID: str = "stabilityai/stable-diffusion-xl-base-1.0"
# VAE checkpoint loaded in place of the pipeline's default VAE.
VAE_PATH: str = "madebyollin/sdxl-vae-fp16-fix"

# Benchmark workload.
NUM_ITERS_TO_RUN: int = 3  # number of timed pipeline calls
NUM_INFERENCE_STEPS: int = 25  # inference steps per pipeline call
NUM_IMAGES_PER_PROMPT: int = 4  # images generated per pipeline call

# Conditioning inputs.
DEPTH_IMAGE_URL: str = (
    "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/d_stormtrooper.png"
)
PROMPT: str = "stormtrooper lecture, photorealistic"
SEED: int = 0  # random seed value
def load_pipeline(controlnet_id):
    """Load the fp16 SDXL ControlNet pipeline with model CPU offloading enabled.

    Args:
        controlnet_id: Hub repo id (or local path) of the ControlNet checkpoint
            to pair with the SDXL base model ``PIPELINE_ID``.

    Returns:
        A ``StableDiffusionXLControlNetPipeline`` using ``controlnet_id`` and the
        VAE from ``VAE_PATH``, with model CPU offloading enabled.
    """
    controlnet = ControlNetModel.from_pretrained(
        controlnet_id,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.float16,
        use_auth_token=True,
    )
    # Separate VAE checkpoint replacing the pipeline's default VAE
    # (by its name, a variant meant for fp16 use -- confirm on the model card).
    vae = AutoencoderKL.from_pretrained(VAE_PATH, torch_dtype=torch.float16)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        PIPELINE_ID,
        controlnet=controlnet,
        vae=vae,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    # No .to("cuda") on the components or the pipeline: model CPU offloading
    # handles device placement itself, moving each sub-model to the GPU only
    # while it runs. Moving everything to the GPU first just gets undone, costs
    # load time, and inflates torch.cuda.max_memory_allocated(), which skews
    # the memory figure this benchmark reports.
    pipe.enable_model_cpu_offload()
    return pipe
def run_inference(args):
    """Time ``NUM_ITERS_TO_RUN`` pipeline calls and record peak CUDA memory.

    Args:
        args: Parsed CLI namespace; only ``args.controlnet_id`` is read.

    Returns:
        dict with keys:
            "controlnet_id": the benchmarked ControlNet id,
            "total_time (ms)": total time of all iterations in milliseconds,
                as a string formatted with one decimal,
            "memory (mb)": peak CUDA memory allocated during inference, in
                MB (10**6 bytes), truncated to an int.
    """
    pipe = load_pipeline(args.controlnet_id)
    depth_image = load_image(DEPTH_IMAGE_URL)

    # Reset the peak counter only after loading so the reported number reflects
    # inference, not allocations made while building the pipeline.
    torch.cuda.reset_peak_memory_stats()
    # Seeded generator so repeated benchmark runs are reproducible.
    generator = torch.Generator().manual_seed(SEED)

    # perf_counter_ns is monotonic; time_ns follows the wall clock, which can jump.
    start = time.perf_counter_ns()
    for _ in range(NUM_ITERS_TO_RUN):
        pipe(
            PROMPT,
            image=depth_image,
            num_inference_steps=NUM_INFERENCE_STEPS,
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
            generator=generator,
        )
    # Ensure all queued GPU work has finished before stopping the clock.
    torch.cuda.synchronize()
    end = time.perf_counter_ns()

    mem_bytes = torch.cuda.max_memory_allocated()
    mem_MB = int(mem_bytes / (10**6))
    total_time = f"{(end - start) / 1e6:.1f}"
    results = {
        "controlnet_id": args.controlnet_id,
        "total_time (ms)": total_time,
        "memory (mb)": mem_MB,
    }
    return results
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with a ``controlnet_id`` attribute.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark an SDXL ControlNet checkpoint."
    )
    parser.add_argument(
        "--controlnet_id",
        type=str,
        # Not required: combining required=True with a default makes the
        # default unreachable. When the flag is omitted, this default is used.
        default="diffusers/controlnet-depth-sdxl-1.0",
        help="Hub id or local path of the ControlNet checkpoint to benchmark.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Parse the CLI, run the benchmark, and print the resulting dict.
    print(run_inference(parse_args()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment