Run SVD locally
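As written, the script below needs the `sgm` package from the Stability-AI/generative-models repository, an `svd.yaml` model config in the working directory (along with whatever SVD checkpoint that config references), and a CUDA GPU; each call to `predict` renders 14 frames from one conditioning image and writes them as an MP4 into `output/`.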
# from cog import BasePredictor, Input, Path
import os
import cv2
import time
import math
import torch
import numpy as np
from PIL import Image
from glob import glob
from typing import Optional
from omegaconf import OmegaConf
from einops import rearrange, repeat
from torchvision.transforms import ToTensor
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
"""Exported from stability/ai generative-models """ | |
def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))
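# get_batch assembles the conditioning dict the SVD conditioner expects (fps_id,
# motion_bucket_id, cond_aug, and the noised/clean conditioning frames), repeating
# each entry to the batch size, plus a tensor-only copy used for the unconditional pass.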
def get_batch(keys, value_dict, N, T, device):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "fps_id":
            batch[key] = (
                torch.tensor([value_dict["fps_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]])
                .to(device)
                .repeat(int(math.prod(N)))
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to(device),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc
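# load_model builds the SVD model from an OmegaConf config, patching in the requested
# sampler step count and number of frames before instantiation.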
def load_model(
    config: str,
    device: str,
    num_frames: int,
    num_steps: int,
):
    config = OmegaConf.load(config)
    if device == "cuda":
        config.model.params.conditioner_config.params.emb_models[
            0
        ].params.open_clip_embedding_config.params.init_device = device

    config.model.params.sampler_config.params.num_steps = num_steps
    config.model.params.sampler_config.params.guider_config.params.num_frames = (
        num_frames
    )
    if device == "cuda":
        with torch.device(device):
            model = instantiate_from_config(config.model).to(device).eval()
    else:
        model = instantiate_from_config(config.model).to(device).eval()
    return model
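# predict generates a short video from a single conditioning image and returns the
# path of the MP4 written under output_folder.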
def predict(input_image) -> str:
    """Run a single prediction on the model"""
    num_frames: Optional[int] = 14
    num_steps: Optional[int] = 25
    fps_id: int = 6
    motion_bucket_id: int = 127
    cond_aug: float = 0.02
    seed: int = 23
    decoding_t: int = 7
    device: str = "cuda"
    output_folder: Optional[str] = "output/"
    model_config = "svd.yaml"

    model = load_model(
        model_config,
        device,
        num_frames,
        num_steps,
    )
    torch.manual_seed(seed)

    input_img_path = input_image
    output_path = None
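    # Load the conditioning image: force RGB, resize down to multiples of 64, scale to [-1, 1].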
    with Image.open(input_img_path) as image:
        if image.mode == "RGBA":
            image = image.convert("RGB")
        w, h = image.size

        if h % 64 != 0 or w % 64 != 0:
            width, height = map(lambda x: x - x % 64, (w, h))
            image = image.resize((width, height))
            print(
                f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
            )

        image = ToTensor()(image)
        image = image * 2.0 - 1.0

    image = image.unsqueeze(0).to(device)
    H, W = image.shape[2:]
    assert image.shape[1] == 3
    F = 8  # VAE spatial downsampling factor
    C = 4  # latent channels
    shape = (num_frames, C, H // F, W // F)
    if (H, W) != (576, 1024):
        print(
            "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
        )
    if motion_bucket_id > 255:
        print(
            "WARNING: High motion bucket! This may lead to suboptimal performance."
        )
    if fps_id < 5:
        print("WARNING: Small fps value! This may lead to suboptimal performance.")
    if fps_id > 30:
        print("WARNING: Large fps value! This may lead to suboptimal performance.")

    value_dict = {}
    value_dict["motion_bucket_id"] = motion_bucket_id
    value_dict["fps_id"] = fps_id
    value_dict["cond_aug"] = cond_aug
    value_dict["cond_frames_without_noise"] = image
    value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
    with torch.no_grad():
        with torch.autocast(device):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1, num_frames],
                T=num_frames,
                device=device,
            )
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=[
                    "cond_frames",
                    "cond_frames_without_noise",
                ],
            )

            for k in ["crossattn", "concat"]:
                uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
                c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

            randn = torch.randn(shape, device=device)

            additional_model_inputs = {}
            additional_model_inputs["image_only_indicator"] = torch.zeros(
                2, num_frames
            ).to(device)
            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

            def denoiser(input, sigma, c):
                return model.denoiser(
                    model.model, input, sigma, c, **additional_model_inputs
                )

            samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
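            # Decode the latents back to pixel space, decoding_t frames at a time to bound memory use.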
            model.en_and_decode_n_samples_a_time = decoding_t
            samples_x = model.decode_first_stage(samples_z)
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

            os.makedirs(output_folder, exist_ok=True)
            base_count = len(glob(os.path.join(output_folder, "*.mp4")))
            video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
            writer = cv2.VideoWriter(
                video_path,
                cv2.VideoWriter_fourcc(*"MP4V"),
                fps_id + 1,
                (samples.shape[-1], samples.shape[-2]),
            )
            output_path = video_path

            samples = embed_watermark(samples)
            vid = (
                (rearrange(samples, "t c h w -> t h w c") * 255)
                .cpu()
                .numpy()
                .astype(np.uint8)
            )
            for frame in vid:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                writer.write(frame)
            writer.release()
    return output_path
t1 = time.time()
for _ in range(10):
    predict("fireworks.png")
t2 = time.time()
print("10 inference runs took", t2 - t1, "seconds")