@a-r-r-o-w · Created September 6, 2024 15:19
Extremely simple and naive video captioning with MiniCPM-V-2_6 and Llama-3.1-8B-Instruct
import argparse
import os
import random

import numpy as np
import pandas as pd
import torch
import transformers
from decord import VideoReader, cpu
from PIL import Image
from transformers import AutoModel, AutoTokenizer

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--instance_data_root", type=str, required=True, help="Base folder where video files are located."
    )
    parser.add_argument("--cache_dir", type=str, default=None, help="Path to where models are stored.")
    parser.add_argument(
        "--output_path", type=str, default="video_dataset.csv", help="File path where the dataset CSV should be stored."
    )
    parser.add_argument("--max_new_tokens", type=int, default=226, help="Maximum number of new tokens to generate.")
    parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
    return parser.parse_args()

SYSTEM_PROMPT = """
You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe.

For example, if you respond with "A beautiful morning in the woods with the sun peeking through the trees", the video generation model will create a video exactly as described. Your task is to summarize the descriptions of videos provided by users, and create detailed prompts to feed into the generative model.

There are a few rules to follow:
- You will only ever output a single video description per request.
- If the user asks to summarize the prompt in [X] words, make sure not to exceed the limit.

Your responses should just be the video generation prompt. Here are examples:
- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall"
""".strip()

USER_PROMPT = """
Could you generate a prompt for a video generation model for the following video summary:
```
{0}
```
Please limit the prompt to [{1}] words.
""".strip()
QUESTION = """
Describe the video. You should pay close attention to every detail in the video and describe it in as much detail as possible.
""".strip()
MAX_NUM_FRAMES = 49
def set_seed(seed: int):
    """
    Helper function for reproducible behavior: sets the seed for `random`, `numpy`, and `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def encode_video(video_path: str):
    """Sample up to MAX_NUM_FRAMES frames from a video and return them as PIL images."""

    def uniform_sample(l, n):
        # Pick n evenly spaced elements, offset to the middle of each interval.
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    # Step through frames at half the average FPS (~2 frames per second of video);
    # guard against a zero stride for very low-FPS inputs.
    sample_fps = max(round(vr.get_avg_fps()) // 2, 1)
    frame_idx = list(range(0, len(vr), sample_fps))
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
    print("num frames:", len(frames))
    return frames
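
# Illustrative arithmetic for the sampling above: a 10 s clip at 30 fps has 300
# frames and a stride of 15, so frame_idx holds 20 indices and is used as-is; a
# 60 s clip at 30 fps yields 120 candidate indices, which uniform_sample thins
# down to MAX_NUM_FRAMES (49).
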
@torch.no_grad()
def main(args: argparse.Namespace):
    set_seed(args.seed)

    video_files = [file for file in os.listdir(args.instance_data_root) if file.endswith(".mp4")]
    video_files = [os.path.join(args.instance_data_root, file) for file in video_files]
    video_descriptions = {}

    # Stage 1: generate a detailed summary of every video with MiniCPM-V-2.6.
    model_id = "openbmb/MiniCPM-V-2_6"
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        attn_implementation="sdpa",
        torch_dtype=torch.bfloat16,
        cache_dir=args.cache_dir,  # honor --cache_dir for this model as well
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, cache_dir=args.cache_dir)

    for filepath in video_files:
        print(f"Generating video summary for file: `{filepath}`")
        frames = encode_video(filepath)
        msgs = [{"role": "user", "content": frames + [QUESTION]}]
        params = {
            "use_image_id": False,
            "max_slice_nums": 10,
        }
        description = model.chat(image=None, msgs=msgs, tokenizer=tokenizer, **params)
        video_descriptions[filepath] = {"summary": description}

    # Free the vision model before loading Llama so both never occupy VRAM at once.
    del model, tokenizer
    torch.cuda.empty_cache()
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipeline = transformers.pipeline(
"text-generation",
model=model_id,
device_map="auto",
model_kwargs={
"local_files_only": True,
"cache_dir": args.cache_dir,
"torch_dtype": torch.bfloat16,
},
)
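
    # With chat-style `messages` input, the pipeline returns the whole conversation
    # under outputs[0]["generated_text"] (a list of message dicts), so the assistant
    # reply is the last entry's "content".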
    for filepath in video_files:
        print(f"Generating captions for file: `{filepath}`")
        for prompt_type, num_words in [("short_prompt", 25), ("prompt", 75), ("verbose_prompt", 125)]:
            user_prompt = USER_PROMPT.format(video_descriptions[filepath]["summary"], num_words)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ]
            outputs = pipeline(
                messages,
                max_new_tokens=args.max_new_tokens,
            )
            video_descriptions[filepath][prompt_type] = outputs[0]["generated_text"][-1]["content"]
    # Write one CSV row per video, with paths relative to the data root.
    df_data = []
    for filepath, description in video_descriptions.items():
        relative_path = os.path.relpath(filepath, args.instance_data_root)
        df_data.append(
            {
                "path": relative_path,
                "short_prompt": description.get("short_prompt", ""),
                "prompt": description.get("prompt", ""),
                "verbose_prompt": description.get("verbose_prompt", ""),
                "summary": description.get("summary", ""),
            }
        )
    df = pd.DataFrame(df_data)
    df.to_csv(args.output_path, index=False)


if __name__ == "__main__":
    args = get_args()
    main(args)
Usage:

python3 workflows/caption_dataset.py --instance_data_root <PATH_TO_DATA> --cache_dir <PATH_TO_MODEL_CACHE_DIR> --output_path <PATH_TO_DATASET_SAVEDIR>
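
Once the run finishes, the CSV can be loaded back with pandas. A minimal sketch, assuming the default --output_path of video_dataset.csv (the column names come from the script above):

import pandas as pd

df = pd.read_csv("video_dataset.csv")
print(df[["path", "short_prompt", "prompt", "verbose_prompt", "summary"]].head())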
