hacky stablediffusion code for generating videos
""" | |
draws many samples from a diffusion model by slerp'ing around | |
the noise space, and dumps frames to a directory. You can then | |
stitch up the frames with e.g.: | |
$ ffmpeg -r 10 -f image2 -s 512x512 -i out/frame%06d.jpg -vcodec libx264 -crf 10 -pix_fmt yuv420p test.mp4 | |
THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE | |
THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE | |
THIS FILE IS HACKY AND NOT CONFIGURABLE READ THE CODE, MAKE EDITS TO PATHS AND SETTINGS YOU LIKE | |
nice slerp def from @xsteenbrugge ty | |
you have to have access to stablediffusion checkpoints from https://huggingface.co/CompVis | |
and install all the other dependencies (e.g. diffusers library) | |
""" | |
from diffusers import StableDiffusionPipeline
from time import time
from PIL import Image
from einops import rearrange
import numpy as np
import torch
from torch import autocast
from torchvision.utils import make_grid
torch.manual_seed(42)

import os
HF_TOKEN = os.environ['HF_TOKEN']

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=HF_TOKEN)

torch_device = 'cuda'
pipe.unet.to(torch_device)
pipe.vae.to(torch_device)
pipe.text_encoder.to(torch_device)
print('w00t')

batch_size = 1
height = 512
width = 512
prompt = ["ultrarealistic steam punk neural network machine in the shape of a brain, placed on a pedestal, covered with neurons made of gears. dramatic lighting. #unrealengine"] * 1
text_input = pipe.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
text_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0]
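# diffuse() below runs the denoising loop by hand instead of calling the pipeline end-to-end:
# it batches the latents twice, predicts noise for the unconditional and text-conditioned
# branches in a single unet call, applies classifier-free guidance, steps the scheduler,
# and finally decodes the latents to pixels with the VAE.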
@torch.no_grad()
def diffuse(text_embeddings, init, guidance_scale=7.5):
    # text_embeddings are n,t,d
    max_length = text_embeddings.shape[1]
    uncond_input = pipe.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latents = init.clone()
    num_inference_steps = 50
    pipe.scheduler.set_timesteps(num_inference_steps)
    for t in pipe.scheduler.timesteps:
        # predict the noise residual
        latent_model_input = torch.cat([latents] * 2)  # for cfg
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = pipe.scheduler.step(noise_pred, t, latents)["prev_sample"]

    # post-process
    latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(latents)
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    return image
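# spherical linear interpolation between two noise tensors; falls back to plain lerp when
# the inputs are nearly colinear (|dot| > DOT_THRESHOLD), where sin(theta_0) ~ 0 would make
# the slerp weights numerically unstable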
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    inputs_are_torch = False
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
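# endless random walk through the noise space: repeatedly sample a fresh random endpoint,
# slerp from the current start to it over 200 frames, then continue the walk from there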
# DREAM
os.makedirs('outputs/mov1', exist_ok=True)  # make sure the frame directory exists

# sample start
init1 = torch.randn((batch_size, pipe.unet.in_channels, height // 8, width // 8)).to(torch_device)

n = 0
while True:
    # sample destination
    init2 = torch.randn((batch_size, pipe.unet.in_channels, height // 8, width // 8)).to(torch_device)

    for i, t in enumerate(np.linspace(0, 1, 200)):
        init = slerp(float(t), init1, init2)

        image = diffuse(text_embeddings, init, guidance_scale=10.0)
        im = Image.fromarray((image[0] * 255).astype(np.uint8))
        im.save('outputs/mov1/frame%06d.jpg' % n)

        print('dreaming... ', n)
        n += 1

    init1 = init2