Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active September 10, 2024 02:04
Show Gist options
  • Save sayakpaul/cfaebd221820d7b43fae638b4dfa01ba to your computer and use it in GitHub Desktop.
Save sayakpaul/cfaebd221820d7b43fae638b4dfa01ba to your computer and use it in GitHub Desktop.
Minimal example to show how to run distributed inference from a set of prompts with diffusers and accelerate.
# Originally by jiwooya1000, put together by sayakpaul.
# Documentation: https://huggingface.co/docs/diffusers/main/en/training/distributed_inference
"""
Run:
accelerate launch distributed_inference_diffusers.py --batch_size 8
# Enable memory optimizations for large models like SD3
accelerate launch distributed_inference_diffusers.py --batch_size 8 --low_mem=1
"""
from diffusers import DiffusionPipeline
from accelerate import Accelerator
from accelerate.utils import gather_object
from tqdm import tqdm
from datasets import load_dataset
import torch
import time
import os
import fire
# Timestamp taken once at import (local time, YYYYmmdd_HHMMSS); main() appends
# it to the output directory name so separate runs do not collide.
START_TIME = time.strftime("%Y%m%d_%H%M%S")
# Translates the CLI ``dtype`` string into the torch dtype the pipeline loads with.
DTYPE_MAP = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
def get_batches(items, batch_size):
    """Split a sliceable sequence into consecutive chunks of at most ``batch_size``.

    The final chunk is shorter when ``len(items)`` is not a multiple of
    ``batch_size``; empty ``items`` yields an empty list.

    Args:
        items: Any object supporting ``len()`` and slicing (list, str, ...).
        batch_size: Maximum number of items per chunk; must be >= 1.

    Returns:
        A list of slices of ``items``, in their original order.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # Zero used to raise ZeroDivisionError; a negative size silently
        # produced no batches, so the whole run did nothing.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    n = len(items)
    # The explicit min() keeps the end index in range even for sequence types
    # that do not clamp out-of-range slice bounds.
    return [items[start:min(start + batch_size, n)] for start in range(0, n, batch_size)]
def _prompt_to_filename(prompt, max_bytes=200):
    """Turn a free-form prompt into a file-name stem that is safe to write to disk."""
    # Whitespace runs become "_" (same convention as before).
    stem = "_".join(prompt.split())
    # Path separators would create sub-directories (or escape save_dir); the
    # remaining characters are rejected by Windows file systems.
    stem = "".join("-" if ch in '/\\:*?"<>|' or ord(ch) < 32 else ch for ch in stem)
    # Long prompts would exceed the common 255-byte file-name limit; cut on the
    # UTF-8 byte length and drop any multi-byte character split by the cut.
    return stem.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")


def _save_batch(images, prompts, save_dir, start_count):
    """Save each image as ``<save_dir>/example_<n>/image_<prompt>.png``; return the new count."""
    count = start_count
    for image, prompt in zip(images, prompts):
        count += 1
        example_dir = os.path.join(save_dir, f"example_{count}")
        os.makedirs(example_dir, exist_ok=True)
        # Previously the image was written to the current working directory,
        # leaving example_dir empty.
        image.save(os.path.join(example_dir, f"image_{_prompt_to_filename(prompt)}.png"))
    return count


def main(
    ckpt_id: str = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    save_dir: str = "./evaluation/examples",
    seed: int = 1,
    batch_size: int = 4,
    num_inference_steps: int = 20,
    guidance_scale: float = 4.5,
    dtype: str = "fp16",
    low_mem: int = 0,
):
    """Generate one image per PartiPrompts prompt, sharding every batch across processes.

    Meant to be started with ``accelerate launch``. Each batch of ``batch_size``
    prompts is split between the launched processes, every process runs the
    pipeline on its share, and the main process gathers all results and writes
    them under ``<save_dir>_<START_TIME>/example_<n>/``.

    Args:
        ckpt_id: Checkpoint passed to ``DiffusionPipeline.from_pretrained``.
        save_dir: Output directory prefix; the launch timestamp is appended.
        seed: Seed given to ``torch.manual_seed`` before every batch.
        batch_size: Prompts per global batch (shared across processes); >= 1.
        num_inference_steps: Forwarded to the pipeline call.
        guidance_scale: Forwarded to the pipeline call.
        dtype: A key of ``DTYPE_MAP`` ("fp32", "fp16" or "bf16").
        low_mem: If truthy, use model CPU offload instead of moving the whole
            pipeline onto this process's device.

    Raises:
        ValueError: If ``dtype`` is unknown or ``batch_size`` is less than 1.
    """
    # Validate cheap arguments before the expensive model load.
    if dtype not in DTYPE_MAP:
        raise ValueError(f"Unsupported dtype {dtype!r}; expected one of {sorted(DTYPE_MAP)}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=DTYPE_MAP[dtype])
    save_dir = save_dir + f"_{START_TIME}"
    parti_prompts = load_dataset("nateraw/parti-prompts", split="train")
    data_loader = get_batches(items=parti_prompts["Prompt"], batch_size=batch_size)
    distributed_state = Accelerator()
    if low_mem:
        pipeline.enable_model_cpu_offload(gpu_id=distributed_state.device.index)
    else:
        pipeline = pipeline.to(distributed_state.device)

    if distributed_state.is_main_process:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            print(f"Directory '{save_dir}' created successfully.")
        else:
            print(f"Directory '{save_dir}' already exists.")

    count = 0
    # Only the main process draws a progress bar; the others would duplicate it.
    for prompts_raw in tqdm(data_loader, disable=not distributed_state.is_main_process):
        with distributed_state.split_between_processes(prompts_raw) as prompts:
            if len(prompts) == 0:
                # A final batch smaller than the number of processes leaves some
                # ranks with no prompts; skip generation but still take part in
                # the gather below, which every rank must join.
                images = []
            else:
                generator = torch.manual_seed(seed)
                images = pipeline(
                    prompts,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images
            input_prompts = list(prompts)
        distributed_state.wait_for_everyone()
        images = gather_object(images)
        input_prompts = gather_object(input_prompts)
        if distributed_state.is_main_process:
            count = _save_batch(images, input_prompts, save_dir, count)

    if distributed_state.is_main_process:
        print(f">>> Image Generation Finished. Saved in {save_dir}")
if __name__ == "__main__":
    # Fire turns main()'s keyword arguments into command-line flags,
    # e.g. ``--batch_size 8 --low_mem=1``.
    fire.Fire(component=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment