Created
September 17, 2023 11:36
-
-
Save cloneofsimo/85f763e06b67815278180a7856a10fa6 to your computer and use it in GitHub Desktop.
preprocess-videos-latents
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import csv | |
import torch | |
import cv2 | |
import logging | |
from typing import Tuple, Any, List | |
from torch.utils.data import DataLoader, Dataset | |
from multiprocessing import Pool | |
from streaming import MDSWriter | |
import ImageReward as RM | |
from PIL import Image | |
from diffusers.models import AutoencoderKL | |
import numpy as np | |
from transformers import BlipProcessor, BlipForConditionalGeneration | |
import pandas as pd | |
import json | |
import time | |
# Initialize logging | |
logging.basicConfig(level=logging.INFO) | |
from streaming.base.format.mds.encodings import Encoding, _encodings | |
class bf16(Encoding): | |
def encode(self, obj: Any) -> bytes: | |
return obj.tobytes() | |
def decode(self, data: bytes) -> Any: | |
return np.frombuffer(data, np.float16) | |
_encodings['bf16'] = bf16 | |
def prepare_image(pil_image, w=512, h=512): | |
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) | |
arr = np.array(pil_image.convert("RGB")) | |
arr = arr.astype(np.float32) / 127.5 - 1 | |
arr = np.transpose(arr, [2, 0, 1]) | |
image = torch.from_numpy(arr).unsqueeze(0) | |
return image | |
class VideoDataset(Dataset): | |
def __init__(self, csv_file): | |
self.video_files = pd.read_csv(csv_file)['video_path'].to_list() | |
#print(self.video_files) | |
self.dataset_latency = [] | |
def __len__(self): | |
return len(self.video_files) | |
def __getitem__(self, idx): | |
start_time = time.time() | |
video_file = self.video_files[idx] | |
frames = self._load_frames_from_video(video_file) | |
second_image = Image.fromarray(frames[1]) | |
second_image = self._center_crop_square_resize(second_image, 512) | |
tiled_image = self._tile_frames(frames) | |
# Check if tiled_image has repeated same images | |
diff = np.array([np.abs(frames[i] - frames[i+1]).sum() for i in range(3)]) | |
if np.all(diff < 1e-5): | |
return None, None | |
self.dataset_latency.append(time.time() - start_time) | |
return second_image, prepare_image(tiled_image, 1024, 1024) | |
def _load_frames_from_video(self, video_path): | |
vid = cv2.VideoCapture(video_path) | |
total_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT)) | |
fps = int(vid.get(cv2.CAP_PROP_FPS)) | |
quarter_second_interval = fps // 4 | |
interval = max(1, quarter_second_interval) | |
frames = [] | |
for i in range(4): | |
vid.set(cv2.CAP_PROP_POS_FRAMES, i * interval) | |
ret, frame = vid.read() | |
if not ret: | |
break | |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
frames.append(frame_rgb) | |
vid.release() | |
return frames | |
def _center_crop_square_resize(self, image, output_size): | |
# Resize image while maintaining aspect ratio | |
aspect = image.width / image.height | |
if aspect > 1: | |
# Landscape orientation - wide image | |
width = int(output_size * aspect) | |
height = output_size | |
else: | |
# Portrait orientation - tall image | |
width = output_size | |
height = int(output_size / aspect) | |
image = image.resize((width, height), Image.BICUBIC) | |
# Center crop to 512x512 | |
left = (image.width - output_size) / 2 | |
top = (image.height - output_size) / 2 | |
right = (image.width + output_size) / 2 | |
bottom = (image.height + output_size) / 2 | |
return image.crop((left, top, right, bottom)) | |
def _tile_frames(self, frames): | |
tiled_image = Image.new('RGB', (512 * 2, 512 * 2)) | |
# Assuming all frames are of the same size | |
for i, frame in enumerate(frames): | |
pil_frame = Image.fromarray(frame) | |
# Center-square resize & center-crop each frame before pasting | |
pil_frame = self._center_crop_square_resize(pil_frame, 512) | |
tiled_image.paste(pil_frame, (512 * (i % 2), 512 * (i // 2))) | |
return tiled_image | |
@torch.no_grad() | |
def convert_to_mds(args: Tuple[List[str], torch.device]): | |
sub_out_roots, device = args | |
logging.info(f"Processing on {device}") | |
logging.info(f"Processing {sub_out_roots}") | |
# Set the device for the current process | |
torch.cuda.set_device(device) | |
# Initialize the models | |
image_reward_model = RM.load("ImageReward-v1.0").to(device).eval() | |
# vae model | |
vae_model = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").half() | |
vae_model = vae_model.to(device).eval() | |
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") | |
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(device).eval() | |
# Load the dataset | |
for sub_out_root in sub_out_roots: | |
dataset = VideoDataset(os.path.join(sub_out_root, 'data.csv')) | |
sub_data_root = os.path.join(sub_out_root, 'data') | |
columns = { | |
'reward_output': 'float32', | |
'vae_output': 'bf16', | |
'caption_output': 'str' | |
} | |
if os.path.exists(sub_data_root): | |
# remove all files in the directory | |
for file in os.listdir(sub_data_root): | |
os.remove(os.path.join(sub_data_root, file)) | |
os.makedirs(sub_data_root, exist_ok=True) | |
with MDSWriter(out=sub_data_root, columns=columns) as out: | |
inference_latencies = [] | |
for data_mid, data_all in dataset: | |
if data_mid is None: | |
continue | |
data_all = data_all.to(device) | |
start_time = time.time() | |
blip_inputs = blip_processor(data_mid, text = " ", return_tensors="pt").to(torch.float16).to(device) | |
blip_out = blip_model.generate(**blip_inputs, max_new_tokens=25, min_new_tokens=4, do_sample=True, top_k=50, temperature=0.7) | |
generated_captions = blip_processor.batch_decode(blip_out, skip_special_tokens=True)[0] | |
reward_output = image_reward_model.score(generated_captions, data_mid) | |
vae_output = vae_model.encode(data_all.half()).latent_dist.sample() | |
print(reward_output, vae_output.shape, generated_captions) | |
#Save the outputs to MDS | |
sample = { | |
'reward_output': reward_output, | |
'vae_output': vae_output.cpu().half().numpy(), | |
'caption_output': generated_captions, | |
} | |
out.write(sample) | |
inference_latencies.append(time.time() - start_time) | |
print(f"Average Inference Latency on {device}: {np.mean(inference_latencies)} seconds") | |
print(f"Average Dataset Processing Latency {device}: {np.mean(dataset.dataset_latency)} seconds") | |
return True | |
def init_worker(): | |
pid = os.getpid() | |
print(f'\nInitialize Worker PID: {pid}', flush=True, end='') | |
def main(video_files: List[str], out_root): | |
# Group into batches of 47 | |
grouped_datasets = [video_files[i:i + 1024] for i in range(0, len(video_files), 1024)] | |
# Make sure we have enough groups for our GPUs | |
num_gpus = torch.cuda.device_count() | |
assert len(grouped_datasets) >= num_gpus, f"Not enough data for {num_gpus} GPUs." | |
# Preprocess videos to CSV | |
os.makedirs(out_root, exist_ok=True) | |
grouped_paths = [] | |
for i, dataset_group in enumerate(grouped_datasets): | |
group_path = os.path.join(out_root, f'group_{i}') | |
os.makedirs(group_path, exist_ok=True) | |
df = pd.DataFrame(dataset_group, columns=['video_path']) | |
csvpath = os.path.join(group_path, 'data.csv') | |
df.to_csv(csvpath, index=False) | |
grouped_paths.append(group_path) | |
print(grouped_paths) | |
logging.info("Videos preprocessed to CSV files.") | |
# Create a round-robin GPU assignment | |
gpu_assignments = [([grouped_paths[i]], torch.device(f'cuda:{i % num_gpus}')) for i in range(len(grouped_paths))] | |
with Pool(num_gpus, initializer=init_worker) as pool: | |
pool.map(convert_to_mds, gpu_assignments) | |
print('Finished') |
Author
cloneofsimo
commented
Sep 17, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment