Last active
April 28, 2024 19:51
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import argparse | |
from PIL import Image | |
from diffusers import AutoPipelineForText2Image | |
import torch | |
# Parsed CLI namespace; populated in __main__ and read as a module-level
# global by get_pipeline() and generate().
args = None
# Define the models
# Local single-file .safetensors checkpoint paths for each model family.
SDXL_MODEL = "/models/stable-diffusion/sdxl/photoStableXL.safetensors"
SD15_MODEL = "/YOUR-MODELS-PATH/skibidi-butt-15.safetensors"
# How many times the interpolated clip is looped into the final webm.
LOOP_COUNT = 4
# Define the pipelines | |
def get_pipeline():
    """Build the pipeline pair for the model selected via ``args.model``.

    Returns:
        tuple: ``(img2img_pipeline, text_pipeline)``.  Both elements are
        always non-None so the caller can generate frame 0 from the prompt
        alone, regardless of model family.

    Fix: the SD1.5 branch previously returned ``None`` as the text pipeline,
    which crashed when no ``--init_image`` was supplied (prompt-only start).
    It now loads a matching SD1.5 text2img pipeline.

    Relies on the module-level ``args`` namespace populated in ``__main__``.
    """
    if "SDXL" in args.model:
        from diffusers import (
            StableDiffusionXLImg2ImgPipeline,
            AutoPipelineForText2Image,
        )

        return (
            StableDiffusionXLImg2ImgPipeline.from_single_file(
                SDXL_MODEL, torch_dtype=torch.float16
            ).to("cuda"),
            AutoPipelineForText2Image.from_pretrained(
                "stabilityai/stable-diffusion-xl-base-1.0",
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True,
            ).to("cuda"),
        )
    else:
        from diffusers import StableDiffusionImg2ImgPipeline

        # AutoPipelineForText2Image comes from the top-of-file import; the
        # runwayml base model provides the text2img stage for SD1.5.
        return (
            StableDiffusionImg2ImgPipeline.from_single_file(
                SD15_MODEL, torch_dtype=torch.float16
            ).to("cuda"),
            AutoPipelineForText2Image.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True,
            ).to("cuda"),
        )
def generate(input_path, prompt, pipeline):
    """Run one diffusion call and return the first generated image.

    Args:
        input_path: Path to a conditioning image for img2img, or ``None``
            for a pure text2img call (``image=None`` is then passed through).
        prompt: Text prompt forwarded to the pipeline.
        pipeline: A diffusers pipeline object; called with the shared
            strength / steps / guidance settings from the global ``args``.

    Returns:
        The first PIL image from the pipeline's ``.images`` list.
    """
    init_image = None
    if input_path:
        # Load and normalise the conditioning image; 1024x1024 matches SDXL's
        # native resolution.
        init_image = Image.open(input_path).convert("RGB")
        init_image = init_image.resize((1024, 1024))
    return pipeline(
        prompt=prompt,
        image=init_image,
        strength=args.strength,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        generator=(
            # Fix: compare against None so a seed of 0 is honoured (0 is
            # falsy), and cast to int because argparse may deliver a float
            # while torch.Generator.manual_seed() requires an integer.
            torch.Generator(device="cuda").manual_seed(int(args.seed))
            if args.seed is not None
            else None
        ),
    ).images[0]
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("-i", "--init_image", type=str, required=False) | |
parser.add_argument( | |
"-o", "--output_path", type=str, required=False, default="output-frames" | |
) | |
parser.add_argument("-c", "--count", type=int, required=True) | |
parser.add_argument("-p", "--prompt", type=str, required=True) | |
parser.add_argument("-t", "--strength", type=float, default=0.65) | |
parser.add_argument("-s", "--seed", type=float, default=None) | |
parser.add_argument("-n", "--num_inference_steps", type=int, default=28) | |
parser.add_argument("-g", "--guidance_scale", type=int, default=10) | |
parser.add_argument( | |
"--skip-frame-generation", | |
action="store_true", | |
default=False, | |
help="Skip frame generation stage", | |
) | |
parser.add_argument( | |
"--gif", | |
action="store_true", | |
default=True, | |
help="Builds output.gif from frames", | |
) | |
parser.add_argument("-ni", "--no_interpolate", action="store_true", default=False) | |
parser.add_argument( | |
"-m", "--model", type=str, choices=["SDXL", "SD15"], default="SDXL" | |
) | |
args = parser.parse_args() | |
## frame generation | |
if not args.skip_frame_generation: | |
img2img_pipeline, text_pipeline = get_pipeline() | |
if not os.path.exists(args.output_path): | |
os.makedirs(args.output_path) | |
if args.init_image: | |
# get initial image | |
src_image = args.init_image | |
# copy the initial image to the output path as frame 0 | |
init_image = Image.open(src_image).convert("RGB") | |
init_image = init_image.resize((1024, 1024)) | |
init_image.save(os.path.join(args.output_path, f'frame_{"0".zfill(4)}.png')) | |
else: | |
# generate our first image from the prompt | |
outimage = generate(None, args.prompt, text_pipeline) | |
outpath = os.path.join(args.output_path, f'frame_{"0".zfill(4)}.png') | |
outimage.save(outpath) | |
src_image = outpath | |
# Generate the images | |
for frame_num in range(args.count + 1): | |
frame_id = str(frame_num + 1).zfill(4) | |
output_file_path = os.path.join(args.output_path, f"frame_{frame_id}.png") | |
print(f"Generating image for {src_image} to {output_file_path}...") | |
outimage = generate(src_image, args.prompt, img2img_pipeline) | |
outimage.save(output_file_path) | |
src_image = output_file_path | |
# post-processing | |
if True: # args.gif: | |
import imageio | |
images = [] | |
# we're skipping the first one since that's been rather overpowering | |
# in my tests; really the first handful tend to be like that, but | |
# the first one is the worst; so we'll skip it for moving pictures | |
for i in range(1, args.count + 1): | |
images.append( | |
imageio.imread( | |
os.path.join(args.output_path, f"frame_{str(i).zfill(4)}.png") | |
) | |
) | |
imageio.mimsave("output.gif", images, duration=1.0) | |
print("GIF saved to output.gif") | |
if not args.no_interpolate: | |
# check if ffmpeg is available | |
if os.system("ffmpeg -version") != 0: | |
print("ffmpeg not found, skipping interpolation") | |
exit(1) | |
os.system( | |
f"ffmpeg -i output.gif -filter \"minterpolate='fps=60'\" output-interpolated-once.webm" | |
) | |
os.system( | |
f'ffmpeg -stream_loop {LOOP_COUNT} -i output-interpolated-once.webm -c copy "{args.prompt}.webm"' | |
) | |
os.remove("output-interpolated-once.webm") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
thanks for sharing. I cannot get this code to run "as-is" using SD1.5, and I think the error has something to do with the fact that 'text_pipeline' is 'None' if the model is SD15.
(if "SDXL", 'text_pipeline' is an instance of 'AutoPipelineForText2Image'. Did you get it to work okay for SD1.5?
Anyhoo, I got it working by adding something along the lines of :
from diffusers import StableDiffusionImg2ImgPipeline
return (
StableDiffusionImg2ImgPipeline.from_single_file(
SD15_MODEL, torch_dtype=torch.float16
).to("cuda"),AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to("cuda")
)
It's super slow, but that's likely my 8 GB of VRAM.