Created
September 12, 2024 22:46
-
-
Save a-r-r-o-w/d070cce059ab4ceab3a9f289ff83c69c to your computer and use it in GitHub Desktop.
Simple script that uses Llama 3.1 as a caption generator, Flux-dev as an image generator and CogVideoX-I2V for image-to-video generation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import argparse
import gc
import json
import os
import pathlib
import random
from typing import Any, Dict

# Enable verbose torch.compile diagnostics (recompiles, graph breaks).
# Must be set before `torch` is imported for the env vars to take effect.
os.environ["TORCH_LOGS"] = "+dynamo,recompiles,graph_breaks"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

import numpy as np
import torch
import transformers
# NOTE(review): CogVideoXDPMScheduler is imported but not used below — confirm
# whether a scheduler swap was intended before removing.
from diffusers import CogVideoXImageToVideoPipeline, CogVideoXDPMScheduler, DiffusionPipeline
from diffusers.utils.logging import get_logger
from diffusers.utils import export_to_video

# Allow TF32 matmul kernels for a speedup on Ampere+ GPUs.
torch.set_float32_matmul_precision("high")

logger = get_logger(__name__)
# System prompt for the Llama caption generator: it must emit exactly one
# video-generation prompt per request and honor any requested word limit.
# NOTE(review): the text contains typos ("You task", "details prompts",
# "provided to by users", "You responses"); left byte-identical because the
# prompt is model input — editing it changes the generated captions.
SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.
For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model.
There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit.
You responses should just be the video generation prompt. Here are examples:
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall"
""".strip()

# Per-request user message; {0} is filled with the randomly chosen word limit.
USER_PROMPT = """
Could you generate a prompt for a video generation model? Please limit the prompt to [{0}] words.
""".strip()
def get_args(argv=None):
    """Parse command-line options for the caption -> image -> video pipeline.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case ``argparse`` reads ``sys.argv[1:]`` — identical to the
            previous zero-argument behavior. Passing a list makes the parser
            usable (and testable) without touching process arguments.

    Returns:
        argparse.Namespace with all parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_videos", type=int, default=5, help="Number of unique videos you would like to generate."
    )
    parser.add_argument(
        "--caption_generator_model_id",
        type=str,
        default="meta-llama/Meta-Llama-3.1-8B-Instruct",
        help="Caption generation model.",
    )
    parser.add_argument(
        "--caption_generator_cache_dir", type=str, default=None, help="Cache directory for caption generation model."
    )
    parser.add_argument(
        "--image_generator_model_id", type=str, default="black-forest-labs/FLUX.1-dev", help="Image generation model."
    )
    parser.add_argument(
        "--image_generator_cache_dir", type=str, default=None, help="Cache directory for image generation model."
    )
    # Bug fix: help text previously said "Caption generation model." (copy-paste).
    parser.add_argument(
        "--image_generator_num_inference_steps",
        type=int,
        default=30,
        help="Number of denoising steps for the image generation model.",
    )
    parser.add_argument("--guidance_scale", type=float, default=6, help="Guidance scale to be use for generation.")
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        help="Whether or not to use cosine dynamic guidance for generation [Recommended].",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/",
        help="Location where generated images and videos should be stored.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Whether or not to compile the transformer of image and video generators.",
    )
    parser.add_argument(
        "--enable_vae_tiling", action="store_true", help="Whether or not to use VAE tiling when encoding/decoding."
    )
    parser.add_argument("--seed", type=int, default=1337, help="Seed for reproducibility.")
    return parser.parse_args(argv)
def reset_memory():
    """Collect Python garbage and reset the CUDA allocator cache and stats."""
    # Order matters: drop Python references first so the CUDA caching
    # allocator can actually release their backing blocks, then clear the
    # peak/accumulated counters so the next print_memory() reads fresh values.
    for release in (
        gc.collect,
        torch.cuda.empty_cache,
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
    ):
        release()
def print_memory(id: str = "") -> None:
    """Print current/peak allocated and peak reserved CUDA memory in GiB.

    Args:
        id: Label shown in the header line, e.g. the pipeline stage name.
    """
    gib = 1024**3
    # Variable names are significant: the `{x=:.2f}` format echoes them.
    memory = torch.cuda.memory_allocated() / gib
    max_memory = torch.cuda.max_memory_allocated() / gib
    max_reserved = torch.cuda.max_memory_reserved() / gib
    report = [
        "",
        f"===== {id} =====",
        f"{memory=:.2f}",
        f"{max_memory=:.2f}",
        f"{max_reserved=:.2f}",
        "===== =====",
        "",
    ]
    print("\n".join(report))
def _sanitize_for_filename(caption: str) -> str:
    """Return the first 25 chars of *caption* with . ' " , replaced by _."""
    return caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")


def _generate_captions(args, output_dir: pathlib.Path) -> list:
    """Stage 1: generate one video prompt per requested video with Llama.

    Saves the captions to ``output_dir / captions.json`` and returns them.
    """
    caption_generator = transformers.pipeline(
        "text-generation",
        model=args.caption_generator_model_id,
        device_map="auto",
        model_kwargs={
            "local_files_only": True,
            "cache_dir": args.caption_generator_cache_dir,
            "torch_dtype": torch.bfloat16,
        },
    )
    print_memory("Initialized Caption Generator")

    captions = []
    for i in range(args.num_videos):
        # Vary the requested word count so captions differ in length.
        num_words = random.choice([25, 50, 75, 100])
        user_prompt = USER_PROMPT.format(num_words)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        outputs = caption_generator(messages, max_new_tokens=226)
        caption = outputs[0]["generated_text"][-1]["content"]
        # The model often wraps its answer in quotes; strip a matched pair.
        if caption.startswith("\"") and caption.endswith("\""):
            caption = caption[1:-1]
        captions.append(caption)
        logger.info(f"Generated caption: {caption}")
    print_memory("Caption Generation complete")

    with open(output_dir / "captions.json", "w") as file:
        json.dump(captions, file)
    # Drop the LLM before loading the image generator to free VRAM.
    del caption_generator
    return captions


def _generate_images(args, output_dir: pathlib.Path, captions: list) -> list:
    """Stage 2: render one 720x480 image per caption with the image model.

    Saves each image as ``{index}_{sanitized caption}.png`` and returns the
    PIL images for the video stage.
    """
    image_generator = DiffusionPipeline.from_pretrained(
        args.image_generator_model_id, cache_dir=args.image_generator_cache_dir, torch_dtype=torch.bfloat16
    )
    image_generator.to("cuda")

    if args.compile:
        image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True)
    if args.enable_vae_tiling:
        image_generator.vae.enable_tiling()
    print_memory("Initialized Image Generator")

    images = []
    for index, caption in enumerate(captions):
        image = image_generator(
            prompt=caption,
            height=480,
            width=720,
            num_inference_steps=args.image_generator_num_inference_steps,
            guidance_scale=3.5,
        ).images[0]
        # Bug fix: `filename` was computed but unused — the file was saved
        # under a placeholder name. Include the sanitized caption in the path.
        filename = _sanitize_for_filename(caption)
        image.save(output_dir / f"{index}_{filename}.png")
        images.append(image)
    print_memory("Image Generator complete")

    # Drop the image pipeline before loading the video generator to free VRAM.
    del image_generator
    return images


def _generate_videos(args, output_dir: pathlib.Path, captions: list, images: list) -> None:
    """Stage 3: animate each image into a 49-frame video with CogVideoX-I2V.

    Saves each clip as ``{index}_{sanitized caption}.mp4`` at 8 fps.
    """
    video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
    )
    video_generator.to("cuda")

    if args.compile:
        video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True)
    if args.enable_vae_tiling:
        video_generator.vae.enable_tiling()
    print_memory("Initialized Video Generator")

    # Seed a dedicated generator once so video sampling is reproducible.
    generator = torch.Generator().manual_seed(args.seed)
    for index, (caption, image) in enumerate(zip(captions, images)):
        video = video_generator(
            image=image,
            prompt=caption,
            height=480,
            width=720,
            num_frames=49,
            num_inference_steps=50,
            guidance_scale=args.guidance_scale,
            use_dynamic_cfg=args.use_dynamic_cfg,
            generator=generator,
        ).frames[0]
        # Bug fix: same unused-`filename` defect as the image stage.
        filename = _sanitize_for_filename(caption)
        export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8)
    print_memory("Video Generator complete")


@torch.no_grad()
def main(args: Dict[str, Any]) -> None:
    """Run the full pipeline: captions (LLM) -> images (Flux) -> videos (CogVideoX).

    Args:
        args: Parsed CLI namespace from ``get_args()``.

    Side effects: creates ``args.output_dir`` and writes captions.json, one
    PNG and one MP4 per caption. Requires a CUDA device.
    """
    output_dir = pathlib.Path(args.output_dir)
    os.makedirs(output_dir.as_posix(), exist_ok=True)

    # Seed every RNG source up front for end-to-end reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # reset_memory() between stages so each pipeline starts with a clean
    # allocator and the per-stage peak-memory prints are meaningful.
    reset_memory()
    captions = _generate_captions(args, output_dir)
    reset_memory()
    images = _generate_images(args, output_dir, captions)
    reset_memory()
    _generate_videos(args, output_dir, captions, images)
# Script entry point: parse CLI options and run the full pipeline.
if __name__ == "__main__":
    args = get_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Run: