Skip to content

Instantly share code, notes, and snippets.

@htoyryla
Last active September 21, 2022 07:18
Show Gist options
  • Save htoyryla/698881b7098514b58de774a3d1562e6c to your computer and use it in GitHub Desktop.
Save htoyryla/698881b7098514b58de774a3d1562e6c to your computer and use it in GitHub Desktop.
Stable diffusion tool, image synthesis, img2img and prompt weights
# stable diffusion tool
# @htoyryla github twitter instagram
# requires diffusers 0.3.0 and a trained model
# relies heavily on code from https://github.com/huggingface/diffusers
# for prompt weighting use prompts like "a subprompt like this:10 / in the style of that:25"
import torch
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from torch import autocast
import PIL
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import inspect
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
def _build_parser():
    """Return the command-line parser for this script.

    Options are registered in table order, each with its type, default and
    help string as listed.
    """
    arg_parser = argparse.ArgumentParser()
    # (flag, type, default, help) for every option, in --help order.
    option_table = [
        ('--text', str, "", 'text prompt'),
        ('--image', str, "", 'path to init image'),
        ('--model', str, "./stable-diffusion-v1-4", 'path to sd model'),
        ('--steps', int, 50, 'diffusion steps'),
        # --s: with --image, int(steps * s) denoising steps are run on the
        # noised init image; larger values run more steps (and presumably
        # move further away from the init image).
        ('--s', float, 0, 'init img strength'),
        ('--g', float, 7.5, 'guidance level'),
        ('--dir', str, "out", 'base directory for storing images'),
        ('--name', str, "test", 'basename for storing images'),
        ('--imageSize', int, 512, 'image size'),
        # --h / --w: 0 means "use --imageSize".
        ('--h', int, 0, 'image height'),
        ('--w', int, 0, 'image width'),
        ('--slices', int, 2, 'attention slices'),
        # --seed: 0 means "draw a random seed".
        ('--seed', int, 0, 'manual seed'),
    ]
    for flag, value_type, default, help_text in option_table:
        arg_parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return arg_parser


parser = _build_parser()
opt = parser.parse_args()
device = "cuda"
# Any dimension left at 0 falls back to the square --imageSize.
opt.h = opt.h or opt.imageSize
opt.w = opt.w or opt.imageSize
def preprocess(image):
    """Convert a PIL RGB image into a (1, C, H, W) float tensor in [-1, 1].

    Width and height are first rounded down to a multiple of 32 and the image
    is resampled (Lanczos) to that size. The resulting tensor shape is printed.
    """
    new_size = tuple(side - side % 32 for side in image.size)
    resized = image.resize(new_size, resample=PIL.Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0  # H x W x C in [0, 1]
    batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    print(batch.shape)
    return 2.0 * batch - 1.0
def numpy_to_pil(images):
    """Convert an image array (or a batch of them) into a list of PIL images.

    A 3-D array is treated as a single image and given a leading batch axis.
    Values are scaled by 255, rounded and cast to uint8, so they are expected
    to lie in [0, 1] (the caller in this script clamps to that range).
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_uint8 = (batch * 255).round().astype("uint8")
    return list(map(Image.fromarray, as_uint8))
to_tensor_tfm = transforms.ToTensor()


def pil_to_latent(input_im):
    """Encode a PIL image into scaled VAE latents.

    Uses the module-level ``vae`` and ``device``. It must be called before
    ``vae.encoder`` is deleted later in the script.

    Args:
        input_im: PIL image. ToTensor maps it to [0, 1]; it is then rescaled
            to [-1, 1] before encoding.

    Returns:
        The mode of the VAE posterior multiplied by 0.18215, the same latent
        scaling applied to the init image elsewhere in this script.
    """
    with torch.no_grad():
        # Fixed: the original referenced an undefined `torch_device`.
        pixels = to_tensor_tfm(input_im).unsqueeze(0).to(device) * 2 - 1  # Note scaling
        # Fixed: encode() returns an output object whose distribution lives
        # in .latent_dist (the init-image path reads it the same way).
        posterior = vae.encode(pixels).latent_dist
        return 0.18215 * posterior.mode()  # or .mean or .sample()
name = opt.name
steps = opt.steps
bs = 1  # batch size
text = opt.text


def _parse_weighted_prompt(prompt_text):
    """Split "a subprompt:10 / in the style of that:25" into prompts and weights.

    A prompt without "/" is returned unchanged with weight 1. In a multi-part
    prompt, each part's weight is the number after its LAST ":" (so subprompt
    text may itself contain ":"). A part with no ":", or whose suffix is not a
    number, keeps its full text and gets weight 1.

    Returns:
        (plist, wlist): stripped subprompts and their weights normalized to
        sum to 1.

    Raises:
        ValueError: if the weights of a multi-part prompt sum to zero.
    """
    if "/" not in prompt_text:  # have only a single prompt
        return [prompt_text], [1]
    parts = prompt_text.split("/")
    print(parts)
    prompts, weights = [], []
    for part in parts:
        sub, sep, weight_str = part.rpartition(":")
        weight = 1.0
        if not sep:
            sub = part
        else:
            try:
                weight = float(weight_str)
            except ValueError:  # the ":" belongs to the prompt text itself
                sub = part
        prompts.append(sub.strip())
        weights.append(weight)
    total = sum(weights)
    if total == 0:
        raise ValueError(f"subprompt weights in {prompt_text!r} sum to zero")
    return prompts, [w / total for w in weights]


# parse prompt with weights
plist, wlist = _parse_weighted_prompt(text)
# load model(s)
pretrained_model_name_or_path = opt.model  # e.g. "./textual_inversion_set2"
use_auth_token = False
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", use_auth_token=use_auth_token
)
vae.eval()
vae.to(device)  # use the configured device instead of a hard-coded .cuda()
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", use_auth_token=use_auth_token
)
unet.eval()
unet.to(device)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=use_auth_token
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=use_auth_token
).to(device)
# NOTE(review): beta settings are hard-coded here rather than read from the
# model's scheduler config; presumably they match the SD v1 training
# schedule, so confirm them when using other models.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
# --slices <= 0 disables attention slicing; otherwise the slice size is
# clamped to at least 1 so a large --slices can't yield a zero slice size.
if opt.slices > 0:
    slice_size = max(1, unet.config.attention_head_dim // opt.slices)
    unet.set_attention_slice(slice_size)
num_inference_steps = opt.steps
scheduler.set_timesteps(num_inference_steps)
# Pick the seed: --seed 0 means "draw a random one". It is printed so that a
# run with a random seed can be reproduced later via --seed.
if opt.seed != 0:
    seed = opt.seed
else:
    seed = torch.Generator().seed()
print("seed:", seed)
generator = torch.Generator(device=device).manual_seed(seed)

# preprocess init image to latents if any
# otherwise initialize random latent
if opt.image != "":
    init_image = Image.open(opt.image).convert("RGB")
    init_image = init_image.resize((opt.w, opt.h))
    init_image = preprocess(init_image)
    # encode the init image into latents and scale the latents
    with torch.no_grad():
        init_latent_dist = vae.encode(init_image.to(device)).latent_dist
        # Sample the posterior with the seeded generator so that img2img runs
        # are reproducible for a given --seed.
        init_latents = init_latent_dist.sample(generator=generator)
    init_latents = 0.18215 * init_latents
    init_latents = torch.cat([init_latents] * bs)
    # int(steps * s) of the schedule is re-run; the remaining leading steps
    # are skipped (see t_start in the denoising loop).
    offset = scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * opt.s) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    # These are step indices into the schedule (the same index as t_start),
    # which is what this LMS scheduler's add_noise is given here. Python ints
    # give an int64 tensor.
    timesteps = torch.tensor([num_inference_steps - init_timestep] * bs, device=device)
    # add noise to latents using the timesteps
    noise = torch.randn(init_latents.shape, device=device, generator=generator)
    latents = scheduler.add_noise(init_latents, noise, timesteps).to(device).detach()
else:
    latents = torch.randn(
        (bs, unet.in_channels, opt.h // 8, opt.w // 8), generator=generator, device=device
    )
    # scale the initial noise by the scheduler's first sigma
    latents = latents * scheduler.sigmas[0]
# discard encoder to save ram (only decode is needed from here on)
del vae.encoder
# encode subprompts into a weighted embedding
text_input = tokenizer(
    plist, padding="max_length", max_length=tokenizer.model_max_length,
    truncation=True, return_tensors="pt",
)
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
# Weighted sum over the subprompts (one embedding row per entry of plist).
# The wlist weights are normalized, so a single prompt passes through as-is.
tembs = sum(w * temb for w, temb in zip(wlist, text_embeddings)).unsqueeze(0)
# Repeat to the batch size so the conditional half matches the unconditional
# half (and the latents) when bs > 1; a no-op for bs == 1.
text_embeddings = tembs.repeat(bs, 1, 1).detach()
# encode empty prompt
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * bs, padding="max_length", max_length=max_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
# Unconditional rows first, conditional second: the denoising loop splits
# the UNet output back apart with .chunk(2) in this order.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# Pass eta only if this scheduler's step() accepts it.
eta = 0
accepts_eta = "eta" in inspect.signature(scheduler.step).parameters
extra_step_kwargs = {"eta": eta} if accepts_eta else {}
latents = latents.to(device)
guidance_scale = opt.g
# img2img skips the leading part of the schedule that the init latents were
# already noised to; text2img runs every step.
if opt.image != "":
    t_start = max(num_inference_steps - init_timestep + offset, 0)
else:
    t_start = 0
with torch.no_grad():
    # Wrap the (sized) timestep slice rather than the enumerate object so
    # tqdm knows the total and can show progress and ETA.
    for i, t in enumerate(tqdm(scheduler.timesteps[t_start:])):
        t_index = t_start + i
        # duplicate latents: one copy for the unconditional, one for the text pass
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[t_index]
        latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
        latent_model_input = latent_model_input.to(unet.dtype)
        t = t.to(unet.dtype)
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.detach()
        # classifier-free guidance: push the prediction away from the unconditional one
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample.detach()
# The UNet is no longer needed; free it before decoding.
del unet
latents = 1 / 0.18215 * latents.detach()
# no_grad: without it, decoding records an autograd graph through the VAE
# decoder and holds its activations in memory for no reason.
with torch.no_grad(), autocast("cuda"):
    image = vae.decode(latents.to(vae.dtype)).sample
print(latents.shape, image.shape)
image = (image / 2 + 0.5).clamp(0, 1)
# autocast may yield half precision; convert to float32 before numpy.
image = image.cpu().permute(0, 2, 3, 1).detach().float().numpy()
image = numpy_to_pil(image)[0]
# Create the output directory if needed; os.path.join also avoids writing to
# "/<name>.png" when --dir is empty.
if opt.dir:
    os.makedirs(opt.dir, exist_ok=True)
image.save(os.path.join(opt.dir, name + ".png"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment