Skip to content

Instantly share code, notes, and snippets.

@shreyshahi
Created March 3, 2024 18:04
Show Gist options
  • Save shreyshahi/93d269f1d3fa6c8270b4a07230450560 to your computer and use it in GitHub Desktop.
Save shreyshahi/93d269f1d3fa6c8270b4a07230450560 to your computer and use it in GitHub Desktop.
Simple code to make stable diffusion dream about cats
# Code inspired by https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
# The slerp function is lifted entirely from the above gist.
import os

import numpy as np
import torch
from diffusers import DiffusionPipeline
def interpolate(v1, v2, step, total_steps):
    """Linearly interpolate between ``v1`` and ``v2``.

    ``step`` is a zero-based index into a sequence of ``total_steps`` points:
    step 0 yields ``v1`` and step ``total_steps - 1`` yields ``v2``.

    Args:
        v1: Start value (anything supporting scalar ``*`` and ``+``).
        v2: End value.
        step: Zero-based index of the current point.
        total_steps: Number of points in the sequence.

    Returns:
        ``(1 - alpha) * v1 + alpha * v2`` where
        ``alpha = step / (total_steps - 1)``. When ``total_steps`` is 1 the
        start value is returned (``alpha`` is taken as 0).
    """
    # A one-point sequence has no span to divide by; treat it as the start
    # point instead of raising ZeroDivisionError.
    alpha = 0.0 if total_steps == 1 else step / (total_steps - 1)
    return (1 - alpha) * v1 + alpha * v2
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between two arrays ``v0`` and ``v1``.

    Args:
        t: Interpolation fraction; 0 yields ``v0`` and 1 yields ``v1``.
        v0: Start array. Anything that is not a NumPy array is treated as a
            torch tensor, in which case ``v1`` is converted the same way.
        v1: End array.
        DOT_THRESHOLD: When the absolute cosine of the angle between the
            inputs exceeds this value, plain linear interpolation is used
            instead of the spherical formula.

    Returns:
        The interpolated array: a NumPy array for NumPy inputs, otherwise a
        torch tensor moved to the device ``v0`` was on.
    """
    inputs_are_torch = False
    input_device = None
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    norm_product = np.linalg.norm(v0) * np.linalg.norm(v1)
    if norm_product == 0:
        # A zero-length vector has no direction, so the angle is undefined;
        # dividing by the zero norm would turn the whole result into NaN.
        use_linear = True
    else:
        # Cosine of the angle between the two (flattened) inputs.
        dot = np.sum(v0 * v1 / norm_product)
        # Nearly (anti-)parallel inputs make sin(theta_0) tiny, so the
        # spherical weights below would be numerically unstable.
        use_linear = np.abs(dot) > DOT_THRESHOLD

    if use_linear:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)
    return v2
def main():
    """Render a sequence of cat images by walking through latent space.

    Loads the ``stabilityai/stable-diffusion-2-1`` pipeline in float16 on
    CUDA, then repeatedly slerps between pairs of random latents. Each
    interpolated latent is rendered with a fixed prompt and saved as a
    numbered JPEG in the output folder, until ``max_frame_number`` frames
    have been written.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        safety_checker=None,
        torch_dtype=torch.float16,
        use_safetensors=True,
    ).to("cuda")

    prompt = "photograph of a cat high quality"
    folder = "cats"
    max_frame_number = 5000
    num_interpolated_frames = 500
    quality = 90
    # presumably the latent size matching this model's native resolution
    # -- TODO confirm against the pipeline's expected latent shape
    latent_shape = (1, 4, 96, 96)

    # image.save() does not create missing directories, so make sure the
    # output folder exists before the first frame is written.
    os.makedirs(folder, exist_ok=True)

    frame_number = 0
    v1 = torch.randn(latent_shape)
    v2 = torch.randn(latent_shape)
    while frame_number < max_frame_number:
        for i in range(num_interpolated_frames):
            # Stop exactly at the limit even when max_frame_number is not a
            # multiple of num_interpolated_frames.
            if frame_number >= max_frame_number:
                break
            # NOTE: t reaches 1.0 on the last iteration, so the final frame
            # of a segment uses the same latent as the first frame of the
            # next segment (v1 becomes v2 below).
            t = i * 1.0 / (num_interpolated_frames - 1.0)
            v = slerp(t, v1, v2)
            print(f"Creating and saving frame number {frame_number:06d}")
            # Latents are cast to half precision to match torch_dtype above.
            image = pipeline(prompt, latents=v.half()).images[0]
            output_path = f"{folder}/{frame_number:06d}.jpg"
            image.save(output_path, quality=quality)
            frame_number += 1
        # Continue the walk from the latent just reached towards a new one.
        v1 = v2
        v2 = torch.randn(latent_shape)
# Run the generation loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment