@drisspg
Created August 27, 2024 19:24
Repro: benchmark script for diffusers pipelines (PixArt-Sigma, SD3, AuraFlow, FLUX) with torchao quantization, including float8 inference, and torch.compile.
import torch
torch.set_float32_matmul_precision("high")
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline
import gc
# NOTE: re-enable these imports to use the int8/int4/fp6/autoquant/sparsify code paths below.
# from torchao.quantization import (
#     int4_weight_only,
#     int8_weight_only,
#     int8_dynamic_activation_int8_weight,
#     quantize_,
#     autoquant,
# )
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
# from torchao.prototype.quant_llm import fp6_llm_weight_only
# from torchao.sparsity import sparsify_, int8_dynamic_activation_int8_semi_sparse_weight
from tabulate import tabulate
import argparse
import json
PROMPT = "Eiffel Tower was Made up of more than 2 million translucent straws to look like a cloud, with the bell tower at the top of the building, Michel installed huge foam-making machines in the forest to blow huge amounts of unpredictable wet clouds in the building's classic architecture."
PREFIXES = {
    "stabilityai/stable-diffusion-3-medium-diffusers": "sd3",
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS": "pixart",
    "fal/AuraFlow": "auraflow",
    "black-forest-labs/FLUX.1-dev": "flux",
}
def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"
def load_pipeline(
    ckpt_id: str,
    fuse_attn_projections: bool,
    compile: bool,
    quantization: str,
    sparsify: bool,
) -> DiffusionPipeline:
    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch.bfloat16).to("cuda")

    if fuse_attn_projections:
        pipeline.transformer.fuse_qkv_projections()
        pipeline.vae.fuse_qkv_projections()

    # autoquant path: compile first, then quantize; every other mode compiles after quantization (below).
    if quantization == "autoquant" and compile:
        pipeline.transformer.to(memory_format=torch.channels_last)
        pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
        pipeline.vae.to(memory_format=torch.channels_last)
        pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

    if not sparsify:
        if quantization == "int8dq":
            quantize_(pipeline.transformer, int8_dynamic_activation_int8_weight())
            quantize_(pipeline.vae, int8_dynamic_activation_int8_weight())
        elif quantization == "int8wo":
            quantize_(pipeline.transformer, int8_weight_only())
            quantize_(pipeline.vae, int8_weight_only())
        elif quantization == "int4wo":
            quantize_(pipeline.transformer, int4_weight_only())
            quantize_(pipeline.vae, int4_weight_only())
        elif quantization == "fp6":
            quantize_(pipeline.transformer, fp6_llm_weight_only())
            quantize_(pipeline.vae, fp6_llm_weight_only())
        elif quantization == "fp8-static":
            # Static activation casting: activations are cast using the fixed scale supplied here.
            pipeline.transformer = quantize_to_float8(
                pipeline.transformer,
                QuantConfig(ActivationCasting.STATIC, torch.tensor([1.0], dtype=torch.float32, device="cuda")),
            )
            # pipeline.vae = quantize_to_float8(pipeline.vae, QuantConfig(ActivationCasting.DYNAMIC), module_filter_fn=module_fn)
        elif quantization == "fp8-dynamic":
            def module_fn(m, name):
                # Log each module name and quantize everything.
                print(name)
                return True

            pipeline.transformer = quantize_to_float8(pipeline.transformer, QuantConfig(ActivationCasting.DYNAMIC), module_filter_fn=module_fn)
            # pipeline.vae = quantize_to_float8(pipeline.vae, QuantConfig(ActivationCasting.DYNAMIC), module_filter_fn=module_fn)
        elif quantization == "autoquant":
            pipeline.transformer = autoquant(pipeline.transformer)
            pipeline.vae = autoquant(pipeline.vae)

    if sparsify:
        sparsify_(pipeline.transformer, int8_dynamic_activation_int8_semi_sparse_weight())
        sparsify_(pipeline.vae, int8_dynamic_activation_int8_semi_sparse_weight())

    if quantization != "autoquant" and compile:
        pipeline.transformer.to(memory_format=torch.channels_last)
        pipeline.transformer = torch.compile(pipeline.transformer, fullgraph=True)
        pipeline.vae.to(memory_format=torch.channels_last)
        pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline
def run_inference(pipe, batch_size):
    _ = pipe(
        prompt=PROMPT,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(2024),
    )


def pretty_print_results(results, precision: int = 6):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))
def run_benchmark(pipeline, args):
    # Warm-up runs (these also trigger compilation when --compile is set).
    for _ in range(5):
        run_inference(pipeline, batch_size=args.batch_size)

    time = benchmark_fn(run_inference, pipeline, args.batch_size)
    torch.cuda.empty_cache()
    memory = bytes_to_giga_bytes(torch.cuda.memory_allocated())  # in GBs.

    info = dict(
        ckpt_id=args.ckpt_id,
        batch_size=args.batch_size,
        fuse=args.fuse_attn_projections,
        compile=args.compile,
        quantization=args.quantization,
        sparsify=args.sparsify,
        memory=memory,
        time=time,
    )
    pretty_print_results(info)
    return info
def serialize_artifacts(info: dict, pipeline, args):
    ckpt_id = PREFIXES[args.ckpt_id]
    prefix = f"ckpt@{ckpt_id}-bs@{args.batch_size}-fuse@{args.fuse_attn_projections}-compile@{args.compile}-quant@{args.quantization}-sparsify@{args.sparsify}"
    info_file = f"{prefix}_info.json"
    with open(info_file, "w") as f:
        json.dump(info, f)

    image = pipeline(
        prompt=PROMPT,
        num_images_per_prompt=args.batch_size,
        generator=torch.manual_seed(0),
    ).images[0]
    image.save(f"{prefix}.png")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_id", default="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", type=str)
parser.add_argument("--fuse_attn_projections", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument(
"--quantization",
default="None",
choices=["int8dq", "int8wo", "int4wo", "autoquant", "fp6", "fp8-static", "fp8-dynamic", "None"],
help="Which quantization technique to apply",
)
parser.add_argument("--sparsify", action="store_true")
parser.add_argument("--batch_size", default=1, type=int, choices=[1, 4, 8])
args = parser.parse_args()
flush()
pipeline = load_pipeline(
ckpt_id=args.ckpt_id,
fuse_attn_projections=args.fuse_attn_projections,
compile=args.compile,
quantization=args.quantization,
sparsify=args.sparsify,
)
info = run_benchmark(pipeline, args)
serialize_artifacts(info, pipeline, args)
# Commands
pip install diffusers tabulate
git clone https://github.com/pytorch/ao
cd ao
git checkout update-rowwise-scaling
USE_CPP=0 psd  # "psd" is presumably a local alias for an editable torchao install, e.g. `python setup.py develop`
python benchmark_pixart.py --ckpt_id black-forest-labs/FLUX.1-dev --quantization fp8-static --compile
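
For reference, here is a minimal, self-contained sketch (not part of the original gist) of the torchao float8 inference entry points the script exercises, applied to a toy module instead of a full pipeline. It assumes the torchao branch installed above and a GPU with float8 support (e.g. H100); the layer sizes and the scale of 1.0 are illustrative only, not calibrated values.

import torch
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8

# Toy bf16 stand-in for pipeline.transformer; layer sizes are illustrative.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# fp8-static path, as in the script: weights are quantized to float8 and activations
# are cast with the fixed scale supplied here.
model = quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.STATIC, torch.tensor([1.0], dtype=torch.float32, device="cuda")),
)

# fp8-dynamic variant (activation scales computed per call), optionally with a module filter:
#   model = quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC),
#                              module_filter_fn=lambda m, name: True)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
print(model(x).shape)

These are the same QuantConfig objects that load_pipeline passes for the fp8-static and fp8-dynamic choices.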