
@a-r-r-o-w
Last active November 11, 2024 04:02
Demonstrates how to use CogVideoX 2B/5B with Diffusers and Optimum-Quanto
import gc
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from optimum.quanto import freeze, quantize, qfloat8, qfloat8_e4m3fn, qfloat8_e5m2, qint8, qint4, qint2

def reset_memory(device):
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)

def print_memory(device):
    memory = torch.cuda.memory_allocated(device) / 1024**3
    max_memory = torch.cuda.max_memory_allocated(device) / 1024**3
    max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory=:.3f}")
    print(f"{max_memory=:.3f}")
    print(f"{max_reserved=:.3f}")

# Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
# For 5B, bfloat16 is the ideal dtype. For 2B, float16 is ideal
device = "cuda"
dtype = torch.bfloat16
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=dtype)
pipe.to(device)
# Optionally, enable memory savings
# pipe.enable_model_cpu_offload()
# pipe.vae.enable_tiling()
# Note that pipe.to(device) can be removed when enabling cpu offloading to save 1 round-trip from cpu to cuda
reset_memory(device)
print("===== Model memory =====")
print_memory(device)
# Weights-only quantization
quantize(pipe.transformer, weights=qfloat8)
quantize(pipe.vae, weights=qfloat8)
freeze(pipe.transformer)
freeze(pipe.vae)
reset_memory(device)
print("===== Quantized model memory =====")
print_memory(device)
prompt = "Photorealisitc movie trailer, urban city with high-rise buildings, tracking shot of young man driving a cycle, intricate details photographed by professional directors"
video = pipe(prompt=prompt, guidance_scale=6, use_dynamic_cfg=True, num_inference_steps=50).frames[0]
print("===== Inference memory =====")
print_memory(device)
export_to_video(video, "output.mp4", fps=8)
# Combining quantization, VAE tiling and cpu offloading can result in < 10 GB usage!
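
For reference, below is a minimal sketch of the low-memory combination the comment above alludes to (weights-only quantization + model CPU offloading + VAE tiling). This is an assumed configuration based on the commented-out calls in the script, not a separately benchmarked setup; the prompt is a placeholder.

```python
# Minimal low-memory sketch (assumed configuration, not benchmarked here):
# weights-only quantization + model CPU offloading + VAE tiling.
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from optimum.quanto import freeze, quantize, qfloat8

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Quantize the two largest components to 8-bit float weights.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.vae, weights=qfloat8)
freeze(pipe.vae)

# No pipe.to("cuda") here: enable_model_cpu_offload() moves each submodule to the
# GPU only while it is needed, so the full pipeline never resides on the device.
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()

video = pipe(
    prompt="A tracking shot of a cyclist riding through a city at dusk",  # placeholder prompt
    guidance_scale=6,
    use_dynamic_cfg=True,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```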
a-r-r-o-w commented Aug 24, 2024

The following results are from an A100, 80 GB.

| Model type | Quantization (weights/activations) | Compiled | Model memory (GB) | Inference memory (GB) | Time (s) |
|------------|------------------------------------|----------|-------------------|-----------------------|----------|
| 5B         | bfloat16/bfloat16                  | False    | 19.760            | 31.742                | 245      |
| 5B         | qint8/bfloat16                     | False    | 14.582            | 26.569                | 252      |
| 5B         | qfloat8/bfloat16                   | False    | 14.585            | 26.575                | 251      |
| 2B         | bfloat16/bfloat16                  | False    | 12.550            | 24.528                | 87       |
| 2B         | qint8/bfloat16                     | False    | 10.934            | 22.910                | 91       |
| 2B         | qfloat8/bfloat16                   | False    | 10.939            | 22.918                | 91       |
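
The "Time" column appears to be wall-clock seconds for the 50-step sampling call. A rough way to reproduce the inference rows is sketched below; it reuses `pipe`, `prompt`, `device`, `reset_memory` and `print_memory` from the script above and may not match the exact harness used for the table.

```python
import time

import torch

reset_memory(device)  # reset peak-memory stats before the timed run
start = time.time()
video = pipe(prompt=prompt, guidance_scale=6, use_dynamic_cfg=True, num_inference_steps=50).frames[0]
torch.cuda.synchronize(device)
elapsed = time.time() - start

print_memory(device)             # max_memory -> "Inference memory (GB)" column
print(f"time: {elapsed:.0f} s")  # -> "Time (s)" column
```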
