@sayakpaul
Last active July 30, 2025 13:21
Benchmarking code for the "torch.compile() in Diffusers" blog post.
# Make sure you are using the latest `bitsandbytes` (at least 0.46.0) and PyTorch nightlies (at least 2.8).
# Put together by sayakpaul and anijain2305
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import FluxPipeline
import argparse
import json
import torch
import time
from functools import partial
import torch.utils.benchmark as benchmark


def get_bnb_config(torch_dtype=torch.bfloat16):
    # Quantize the transformer and the T5 text encoder (text_encoder_2) to 4-bit NF4 with bitsandbytes.
    quant_kwargs = {"load_in_4bit": True, "bnb_4bit_compute_dtype": torch_dtype, "bnb_4bit_quant_type": "nf4"}
    quant_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs=quant_kwargs,
        components_to_quantize=["transformer", "text_encoder_2"],
    )
    return quant_config


def initialize_pipeline(
    quant_config=None, torch_dtype=torch.bfloat16, compile=False, regional_compile=False, model_offload=False
):
    ckpt_id = "black-forest-labs/FLUX.1-dev"
    pipe = FluxPipeline.from_pretrained(
        ckpt_id,
        torch_dtype=torch_dtype,
        quantization_config=quant_config,
    )
    if model_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")

    compiling = compile or regional_compile
    if quant_config is not None and compiling:
        # bitsandbytes dequantization involves ops with dynamic output shapes, so tell Dynamo to capture them.
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
    if compiling:
        if model_offload:
            if compile:
                pipe.transformer.compile()
            elif regional_compile:
                pipe.transformer.compile_repeated_blocks(fullgraph=True)
        else:
            if compile:
                pipe.transformer.compile(fullgraph=True)
            elif regional_compile:
                # No cudagraphs when using regional compilation
                pipe.transformer.compile_repeated_blocks(fullgraph=True)

    for name, module in pipe.components.items():
        if isinstance(module, torch.nn.Module):
            print(name, f"{module.device=}")
    pipe.set_progress_bar_config(disable=True)
    return pipe


def benchmark_fn(func_to_benchmark):
    # Use torch.utils.benchmark to get a stable mean latency (in seconds).
    t0 = benchmark.Timer(
        stmt="func_to_benchmark()",
        globals={"func_to_benchmark": func_to_benchmark},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def get_pipe_kwargs(num_inference_steps=28):
    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "num_inference_steps": num_inference_steps,
        "max_sequence_length": 512,
        "generator": torch.manual_seed(0),
    }
    return pipe_kwargs


def run_inference(pipe, pipe_kwargs):
    _ = pipe(**pipe_kwargs)


def run_compile_time_inference(pipe):
    # First two calls trigger compilation; time them.
    t0 = time.perf_counter()
    for _ in range(2):
        run_inference(pipe, get_pipe_kwargs(num_inference_steps=1))
    t1 = time.perf_counter()
    compile_time = t1 - t0

    # Run the same two calls again after warm-up to get a baseline.
    t0 = time.perf_counter()
    for _ in range(2):
        run_inference(pipe, get_pipe_kwargs(num_inference_steps=1))
    t1 = time.perf_counter()
    baseline_time = t1 - t0

    # The difference approximates the compilation overhead.
    return compile_time - baseline_time


def main(args):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats("cuda")
    torch.cuda.reset_accumulated_memory_stats("cuda")

    quant_config = get_bnb_config() if args.do_quant else None
    pipe = initialize_pipeline(
        quant_config=quant_config,
        compile=args.compile,
        regional_compile=args.regional_compile,
        model_offload=args.model_offload,
    )
    pipe_kwargs = get_pipe_kwargs()

    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        # Compile-time benchmarking: run the inference twice.
        compile_time = 0.0
        if args.compile or args.regional_compile:
            compile_time = run_compile_time_inference(pipe)

        # Warm up, then measure steady-state latency and peak memory.
        for _ in range(2):
            run_inference(pipe, pipe_kwargs)
        inference_func = partial(run_inference, pipe, pipe_kwargs)
        latency = float(benchmark_fn(inference_func))
        inference_memory = round(torch.cuda.max_memory_allocated() / (1024**3), 3)
        image = pipe(**get_pipe_kwargs()).images[0]

    artifact_dict = {"time": latency, "memory": inference_memory, "compile_time": compile_time}
    artifact_dict.update(vars(args))
    file_prefix = f"comp@{args.compile}-reg_comp@{args.regional_compile}-quant@{args.do_quant}-mo@{args.model_offload}"
    image.save(f"{file_prefix}.png")
    with open(f"{file_prefix}.json", "w") as f:
        json.dump(artifact_dict, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--do_quant", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--regional_compile", action="store_true")
    parser.add_argument("--model_offload", action="store_true")
    args = parser.parse_args()
    main(args)

nvidia-smi:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:53:00.0 Off |                    0 |
| N/A   33C    P0             71W /  700W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

diffusers-cli env:

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

- 🤗 Diffusers version: 0.35.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.8.0.dev20250604+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.52.0.dev0
- Accelerate version: 1.6.0
- PEFT version: 0.15.2
- Bitsandbytes version: 0.46.0
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Commands:

python compile_script_blog_post.py
python compile_script_blog_post.py --compile
python compile_script_blog_post.py --regional_compile
python compile_script_blog_post.py --model_offload
python compile_script_blog_post.py --model_offload --compile
python compile_script_blog_post.py --model_offload --regional_compile
python compile_script_blog_post.py --do_quant
python compile_script_blog_post.py --do_quant --compile
python compile_script_blog_post.py --do_quant --regional_compile
python compile_script_blog_post.py --do_quant --model_offload
python compile_script_blog_post.py --do_quant --model_offload --compile
python compile_script_blog_post.py --do_quant --model_offload --regional_compile
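
The twelve invocations above can also be driven from a small loop rather than typed out one by one. A minimal sketch, assuming the script is saved as `compile_script_blog_post.py` and that each configuration should run in its own process so peak-memory stats start fresh:

# Hypothetical driver: sweeps the same twelve flag combinations as the commands above.
import itertools
import subprocess

for do_quant, model_offload, compile_mode in itertools.product(
    [False, True], [False, True], [None, "--compile", "--regional_compile"]
):
    cmd = ["python", "compile_script_blog_post.py"]
    if do_quant:
        cmd.append("--do_quant")
    if compile_mode is not None:
        cmd.append(compile_mode)
    if model_offload:
        cmd.append("--model_offload")
    subprocess.run(cmd, check=True)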
    time_s  memory_GB  compile_time_s  do_quant  compile  regional_compile  model_offload
0    4.473     33.855         67.4496     False     True             False          False
1    4.501     33.851         9.55139     False    False              True          False
2    5.024     14.973         110.863      True     True             False          False
3    5.048     14.999         11.4292      True    False              True          False
4    6.669     33.851               0     False    False             False          False
5    7.279     14.968               0      True    False             False          False
6    9.822     12.237         13.1464      True    False              True           True
7     9.85     12.227         120.305      True     True             False           True
8   12.224     12.237               0      True    False             False           True
9   17.593     22.706         79.2352     False     True             False           True
10  18.658      22.55         10.1685     False    False              True           True
11  21.488     22.659               0     False    False             False           True
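
Each run writes its numbers to a `<file_prefix>.json` artifact, so a table like the one above can be assembled from those files. A minimal sketch, assuming pandas is installed and the JSON files sit in the working directory:

# Hypothetical aggregation of the per-run JSON artifacts into one summary table.
import glob
import json

import pandas as pd

rows = []
for path in sorted(glob.glob("comp@*.json")):
    with open(path) as f:
        rows.append(json.load(f))

df = pd.DataFrame(rows).sort_values("time").reset_index(drop=True)
df = df.rename(columns={"time": "time_s", "memory": "memory_GB", "compile_time": "compile_time_s"})
print(df[["time_s", "memory_GB", "compile_time_s", "do_quant", "compile", "regional_compile", "model_offload"]])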