Created
August 26, 2022 11:49
-
-
Save htoyryla/8684c76c32ea9653fb0eab7b57bdb3a3 to your computer and use it in GitHub Desktop.
Simple script to read frames from a video and use as init images in stablediffusion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# Usage:
#
#   python vid2.py "prompt" video_in_path output_dir [strength scale seed]
#
# requires image_to_image.py from https://github.com/huggingface/diffusers/blob/main/examples/inference/image_to_image.py
#
import os
import sys

import numpy as np
import requests
import torch
from moviepy.editor import VideoFileClip
from PIL import Image
from torch import autocast

from image_to_image import StableDiffusionImg2ImgPipeline, preprocess
# Device used for the pipeline, the random generator and autocast.
DEVICE = "cuda"
# Hugging Face model id of the Stable Diffusion checkpoint to load.
MODEL_ID = "CompVis/stable-diffusion-v1-4"
# (width, height) every video frame is resized to before being handed to
# the img2img pipeline.
FRAME_SIZE = (768, 512)
# Offset in seconds at which reading of the input video starts (0 = start).
START = 0
# Defaults for the optional command-line arguments.
DEFAULT_STRENGTH = 0.5
DEFAULT_SCALE = 7.5
DEFAULT_SEED = 1024

USAGE = (
    'Usage: python vid2.py "prompt" video_in_path output_dir '
    "[strength scale seed]"
)


def parse_args(argv):
    """Split a ``sys.argv``-style list into the script's settings.

    Args:
        argv: Argument list whose element 0 is the program name, followed by
            prompt, input video path, output directory and, optionally,
            strength, guidance scale and seed (in that order).

    Returns:
        A tuple ``(prompt, ipath, opath, strength, scale, seed)`` where
        ``strength`` and ``scale`` are floats and ``seed`` is an int. Missing
        optional values fall back to the ``DEFAULT_*`` constants.

    Raises:
        SystemExit: If fewer than three positional arguments are given; the
            exit message is the usage line.
        ValueError: If an optional argument is not a valid number.
    """
    if len(argv) < 4:
        raise SystemExit(USAGE)

    prompt, ipath, opath = argv[1:4]
    # Anything past the sixth argument is ignored, as before.
    optional = argv[4:7]
    strength = float(optional[0]) if len(optional) > 0 else DEFAULT_STRENGTH
    scale = float(optional[1]) if len(optional) > 1 else DEFAULT_SCALE
    seed = int(optional[2]) if len(optional) > 2 else DEFAULT_SEED
    return prompt, ipath, opath, strength, scale, seed


def load_pipeline():
    """Load the fp16 img2img pipeline and move it to ``DEVICE``."""
    return StableDiffusionImg2ImgPipeline.from_pretrained(
        MODEL_ID,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=True,
    ).to(DEVICE)


def main(argv=None):
    """Run img2img on every frame of a video and save the results as PNGs.

    Frames are written to ``<output_dir>/frame-<n>.png`` with ``n`` counting
    up from 0 in the order the frames are read.

    Args:
        argv: ``sys.argv``-style argument list; defaults to ``sys.argv``.
    """
    # Parse first so that a usage error is reported before the (slow)
    # model load instead of after it.
    prompt, ipath, opath, strength, scale, seed = parse_args(
        sys.argv if argv is None else argv
    )

    # Saving the first frame would otherwise fail on a missing directory.
    os.makedirs(opath, exist_ok=True)

    pipe = load_pipeline()

    clip = VideoFileClip(ipath)
    try:
        # Only trim when a start offset is configured. The end time is left
        # at its default (end of clip); an explicit end of 0 would yield a
        # non-positive duration for any start > 0.
        source = clip.subclip(START) if START > 0 else clip

        for count, frame in enumerate(source.iter_frames()):
            pixels = np.asarray(frame, dtype=np.uint8)
            init_image = Image.fromarray(pixels, "RGB").resize(FRAME_SIZE)
            # preprocess() comes from image_to_image.py; presumably it turns
            # the PIL image into the tensor the pipeline expects -- confirm
            # against that module.
            init_image = preprocess(init_image)

            # A fresh generator with the same seed is created for every
            # frame, so each frame starts from identical random state.
            generator = torch.Generator(DEVICE).manual_seed(seed)
            with autocast(DEVICE):
                images = pipe(
                    prompt=prompt,
                    init_image=init_image,
                    strength=strength,
                    guidance_scale=scale,
                    generator=generator,
                )["sample"]

            images[0].save(os.path.join(opath, "frame-" + str(count) + ".png"))
    finally:
        # Release the video reader even if a frame fails.
        clip.close()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment