Demonstrates how to use CogVideoX 2B/5B with Diffusers and TorchAO. The gist contains a bash driver that sweeps quantization, compilation, and QKV-fusion combinations; the benchmark script it invokes (cogvideox-torchao-benchmark.py); a standalone example of serializing and reloading a quantized transformer; and benchmark results collected on an H100.
#!/bin/bash

compile_flags=("" "--compile")
fuse_qkv_flags=("" "--fuse_qkv")
# quantizations=("fp16" "bf16" "fp8" "fp8_e4m3" "fp8_e5m2" "fp6" "int8wo" "int8dq" "int4dq" "int4wo" "autoquant" "sparsify")
quantizations=("fp16" "bf16" "fp6" "int8wo" "int8dq" "int4dq" "int4wo" "autoquant" "sparsify")
device="cuda"

# Check if completed.txt exists and read it into an array
if [ -f completed.txt ]; then
  mapfile -t completed_runs < completed.txt
else
  completed_runs=()
fi

for quantization in "${quantizations[@]}"; do
  for compile in "${compile_flags[@]}"; do
    for fuse_qkv in "${fuse_qkv_flags[@]}"; do
      cmd="python3 cogvideox-torchao-benchmark.py $compile $fuse_qkv --dtype $quantization --device $device"

      # Skip the command if it is in the list of completed runs
      if [[ " ${completed_runs[*]} " =~ " ${cmd} " ]]; then
        echo "Skipping already completed command: $cmd"
        continue
      fi

      echo "Running command: $cmd"
      eval "$cmd"

      # Record the finished command so an interrupted sweep can resume where it left off
      echo "$cmd" >> completed.txt
      echo -ne "------------------ Finished executing script ------------------\n\n"
    done
  done
done
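For reference, the same sweep can also be driven from Python. The following is a hedged sketch, not part of the original gist, with the same resume-from-completed.txt behavior:

# Hypothetical Python equivalent of the bash sweep above (sketch only).
import itertools
import subprocess
from pathlib import Path

quantizations = ["fp16", "bf16", "fp6", "int8wo", "int8dq", "int4dq", "int4wo", "autoquant", "sparsify"]
completed_file = Path("completed.txt")
completed = set(completed_file.read_text().splitlines()) if completed_file.exists() else set()

for quantization, compile_flag, fuse_qkv_flag in itertools.product(
    quantizations, ["", "--compile"], ["", "--fuse_qkv"]
):
    cmd = f"python3 cogvideox-torchao-benchmark.py {compile_flag} {fuse_qkv_flag} --dtype {quantization} --device cuda"
    if cmd in completed:
        print(f"Skipping already completed command: {cmd}")
        continue
    print(f"Running command: {cmd}")
    subprocess.run(cmd.split(), check=False)  # keep sweeping even if one configuration fails
    with completed_file.open("a") as f:
        f.write(cmd + "\n")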
# cogvideox-torchao-benchmark.py
import argparse
import gc
import os

# Set TORCH_LOGS before importing torch so the logging configuration takes effect
os.environ["TORCH_LOGS"] = "dynamo,output_code,graph_breaks,recompiles"

import torch
import torch.utils.benchmark as benchmark
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from tabulate import tabulate
from torchao.quantization import (
    autoquant,
    quantize_,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
    int4_weight_only,
)
from torchao.sparsity import sparsify_
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
from torchao.prototype.quant_llm import fp6_llm_weight_only
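# torch.compile / inductor tuning knobs: enable TF32 matmuls and more aggressive
# kernel tuning (these trade longer compile time for faster generated kernels).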
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
DTYPE_CONVERTER = {
    "fp32": lambda module: module.to(dtype=torch.float32),
    "fp16": lambda module: module.to(dtype=torch.float16),
    "bf16": lambda module: module.to(dtype=torch.bfloat16),
    "fp8": lambda module: quantize_to_float8(module, QuantConfig(ActivationCasting.DYNAMIC)),
    "fp8_e4m3": lambda module: module.to(dtype=torch.float8_e4m3fn),
    "fp8_e5m2": lambda module: module.to(dtype=torch.float8_e5m2),
    "fp6": lambda module: quantize_(module, fp6_llm_weight_only()),
    "int8wo": lambda module: quantize_(module, int8_weight_only()),
    "int8dq": lambda module: quantize_(module, int8_dynamic_activation_int8_weight()),
    "int4dq": lambda module: quantize_(module, int8_dynamic_activation_int4_weight()),
    "int4wo": lambda module: quantize_(module, int4_weight_only()),
    "autoquant": lambda module: autoquant(module, error_on_unseen=False),
    "sparsify": lambda module: sparsify_(module, int8_dynamic_activation_int8_semi_sparse_weight()),
}
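# Note: the casting converters above return the converted module, while torchao's
# quantize_/sparsify_ mutate the module in place (and may return None); load_pipeline
# below keeps the original module whenever a converter returns None.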
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"
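# Hypothetical usage: benchmark_fn(run_inference, pipe) -> "87.155", i.e. the mean
# runtime in seconds over Timer.blocked_autorange, formatted to three decimals.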
def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)


def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.3f}")
    print(f"{max_memory=:.3f}")
    print(f"{max_reserved=:.3f}")
def pretty_print_results(results, precision: int = 6):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))
def load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv):
    # 1. Load pipeline
    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.set_progress_bar_config(disable=True)

    if fuse_qkv:
        pipe.fuse_qkv_projections()

    # 2. Quantize and compile. autoquant must wrap the transformer before compilation,
    # while every other scheme is applied first and compiled afterwards.
    if dtype == "autoquant" and compile:
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
        # VAE cannot be compiled due to: https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f#file-test_cogvideox_torch_compile-py-L30

    # Converters either return the converted module or quantize in place and return None
    text_encoder_return = DTYPE_CONVERTER[dtype](pipe.text_encoder)
    transformer_return = DTYPE_CONVERTER[dtype](pipe.transformer)

    vae_return = None
    if dtype in ["fp32", "fp16", "bf16", "fp8_e4m3", "fp8_e5m2"] or quantize_vae:
        vae_return = DTYPE_CONVERTER[dtype](pipe.vae)

    if text_encoder_return is not None:
        pipe.text_encoder = text_encoder_return
    if transformer_return is not None:
        pipe.transformer = transformer_return
    if vae_return is not None:
        pipe.vae = vae_return

    if dtype != "autoquant" and compile:
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
        # VAE cannot be compiled due to: https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f#file-test_cogvideox_torch_compile-py-L30

    return pipe
def run_inference(pipe):
    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )
    video = pipe(
        prompt=prompt,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(3047),  # https://arxiv.org/abs/2109.08203
    )
    return video
def main(dtype, device, quantize_vae, compile, fuse_qkv):
    # 1. Load pipeline
    model_id = "THUDM/CogVideoX-5b"  # or "THUDM/CogVideoX-2b"
    pipe = load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv)

    reset_memory(device)
    print_memory(device)
    torch.cuda.empty_cache()
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)

    # 2. Warmup
    num_warmups = 2
    for _ in range(num_warmups):
        video = run_inference(pipe)

    # 3. Benchmark
    time = benchmark_fn(run_inference, pipe)
    print_memory(device)
    torch.cuda.empty_cache()
    inference_memory = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

    # 4. Save results
    model_type = "5B" if "5b" in model_id else "2B"
    info = {
        "model_type": model_type,
        "compile": compile,
        "fuse_qkv": fuse_qkv,
        "quantize_vae": quantize_vae,
        "quantization": dtype,
        "model_memory": model_memory,
        "inference_memory": inference_memory,
        "time": time,
    }
    pretty_print_results(info, precision=3)

    export_to_video(
        video.frames[0], f"output-quantization_{dtype}-compile_{compile}-fuse_qkv_{fuse_qkv}-{model_type}.mp4", fps=8
    )
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        choices=[
            "fp32",
            "fp16",
            "bf16",
            "fp8",
            "fp8_e4m3",
            "fp8_e5m2",
            "fp6",
            "int8wo",
            "int8dq",
            "int4dq",
            "int4wo",
            "autoquant",
            "sparsify",
        ],
    )
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--quantize_vae", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument("--fuse_qkv", action="store_true", default=False)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args.dtype, args.device, args.quantize_vae, args.compile, args.fuse_qkv)
# Requires torchao installed from source and PyTorch Nightly.
# Other environments have not yet been tested.
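# For example (hedged; the exact index URL depends on your CUDA version):
#   pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
#   pip install git+https://github.com/pytorch/ao.git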
import tempfile

import torch
from diffusers import CogVideoXTransformer3DModel, CogVideoXPipeline
from diffusers.utils import export_to_video
from torchao.quantization import (
    quantize_,
    int4_weight_only,
    int8_dynamic_activation_int4_weight,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
)
# Either "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
model_id = "THUDM/CogVideoX-5b"

# 1. Quantize and save the transformer
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
quantize_(transformer, int8_weight_only())

with tempfile.NamedTemporaryFile() as file:
    torch.save(transformer.state_dict(), file)
    file.seek(0)
    state_dict = torch.load(file, map_location="cpu")

# 2. Create a fresh model and load the quantized state dict.
# from_config initializes random weights, so the config is loaded explicitly and the
# quantized tensors are attached via load_state_dict(..., assign=True).
config = CogVideoXTransformer3DModel.load_config(model_id, subfolder="transformer")
transformer = CogVideoXTransformer3DModel.from_config(config)
transformer.load_state_dict(state_dict, assign=True, strict=True)
# 3. Create pipeline and run inference
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)

video = pipe(
    prompt=prompt,
    guidance_scale=6,
    use_dynamic_cfg=True,
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(3047),  # https://arxiv.org/abs/2109.08203
).frames[0]

export_to_video(video, "output.mp4", fps=8)
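The example above round-trips the quantized state dict through a temporary file. To keep the quantized checkpoint around, a minimal sketch (the filename is hypothetical; torch.save/torch.load are used because the quantized state dict contains tensor subclasses rather than plain tensors):

# Hedged sketch: persist the quantized weights to disk instead of a temp file.
torch.save(transformer.state_dict(), "cogvideox-5b-transformer-int8wo.pt")

# Later, reload and attach them to a freshly initialized model as above:
state_dict = torch.load("cogvideox-5b-transformer-int8wo.pt", map_location="cpu")
transformer.load_state_dict(state_dict, assign=True, strict=True)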
The following results are from an H100. Memory columns are in GiB (as computed by the script); time is the mean runtime in seconds for a 50-step inference run.

| model_type | compile | fuse_qkv | quantize_vae | quantization | model_memory (GiB) | inference_memory (GiB) | time (s) |
|:----------:|:-------:|:--------:|:------------:|:------------:|:------------------:|:----------------------:|:--------:|
|     5B     |  False  |   True   |    False     |     fp16     |       21.978       |         33.988         | 113.945  |
|     5B     |  True   |   True   |    False     |     fp16     |       21.979       |         33.990         |  87.155  |
|     5B     |  False  |   True   |    False     |     bf16     |       21.979       |         33.988         | 112.398  |
|     5B     |  True   |   True   |    False     |     bf16     |       21.979       |         33.987         |  87.455  |
|     5B     |  False  |   True   |    False     |     fp8      |       11.374       |         23.383         | 113.167  |
|     5B     |  True   |   True   |    False     |     fp8      |       11.374       |         23.383         |  75.255  |
|     5B     |  False  |   True   |    False     |    int8wo    |       11.414       |         23.422         | 123.144  |
|     5B     |  True   |   True   |    False     |    int8wo    |       11.414       |         23.423         |  87.026  |
|     5B     |  True   |   True   |    False     |    int8dq    |       11.412       |         59.355         |  78.945  |
|     5B     |  False  |   True   |    False     |    int4dq    |       12.785       |         24.793         | 151.242  |
|     5B     |  True   |   True   |    False     |    int4dq    |       12.785       |         24.795         |  87.403  |
|     5B     |  False  |   True   |    False     |    int4wo    |       6.824        |         18.829         | 667.125  |