Stable diffusion tool, image synthesis, img2img and prompt weights
# stable diffusion tool
# @htoyryla github twitter instagram
# requires diffusers 0.3.0 and a trained model
# relies heavily on code from https://github.com/huggingface/diffusers
# for prompt weighting use prompts like "a subprompt like this:10 / in the style of that:25"
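#
# example invocation (the script filename below is only illustrative, use whatever name you saved this file under):
#   python sdtool.py --text "a portrait of a robot:60 / in the style of an oil painting:40" --steps 50 --g 7.5 --dir out --name robot
#   add "--image init.png --s 0.6" to start from an init image (img2img)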
import torch
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from torch import autocast
import PIL
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import inspect

from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
parser = argparse.ArgumentParser()
parser.add_argument('--text', type=str, default="", help='text prompt')
parser.add_argument('--image', type=str, default="", help='path to init image')
parser.add_argument('--model', type=str, default="./stable-diffusion-v1-4", help='path to sd model')
parser.add_argument('--steps', type=int, default=50, help='diffusion steps')
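# --s behaves like img2img strength: with an init image, 0 leaves it essentially untouched,
# while values closer to 1 add more noise and let the prompt change more of the image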
parser.add_argument('--s', type=float, default=0, help='init img strength')
parser.add_argument('--g', type=float, default=7.5, help='guidance level')
parser.add_argument('--dir', type=str, default="out", help='base directory for storing images')
parser.add_argument('--name', type=str, default="test", help='basename for storing images')
parser.add_argument('--imageSize', type=int, default=512, help='image size')
parser.add_argument('--h', type=int, default=0, help='image height')
parser.add_argument('--w', type=int, default=0, help='image width')
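# more attention slices lower peak VRAM use in the UNet attention layers at some cost in speed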
parser.add_argument('--slices', type=int, default=2, help='attention slices')
parser.add_argument('--seed', type=int, default=0, help='manual seed')

opt = parser.parse_args()
device = "cuda"

if opt.h == 0:
    opt.h = opt.imageSize
if opt.w == 0:
    opt.w = opt.imageSize
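# the model works on 8x downsampled latents, so --h and --w should be multiples of 8 (multiples of 64 are safest)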
def preprocess(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    print(image.shape)
    return 2.0 * image - 1.0
def numpy_to_pil(images):
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
to_tensor_tfm = transforms.ToTensor()
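# helper for encoding a PIL image directly into scaled latents (0.18215 is Stable Diffusion's latent scaling constant);
# not used in the main flow below, which encodes the init image inline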
def pil_to_latent(input_im):
    with torch.no_grad():
        latent = vae.encode(to_tensor_tfm(input_im).unsqueeze(0).to(device) * 2 - 1)  # note scaling
    return 0.18215 * latent.latent_dist.mode()  # or .mean or .sample()
name = opt.name
steps = opt.steps
bs = 1
text = opt.text
# parse prompt with weights
plist = []
wlist = []
wsum = 0
if "/" in text:  # split into subprompts
    parts = text.split("/")
    print(parts)
    for p in parts:
        ps = p.split(":")
        plist.append(ps[0].strip())
        w = float(ps[1])
        wlist.append(w)
        wsum += w
    for i in range(0, len(wlist)):
        wlist[i] = wlist[i] / wsum  # normalize weights to sum to 1
else:  # have only a single prompt
    plist = [text]
    wlist = [1]
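# e.g. "a portrait:10 / in the style of monet:30" -> plist = ["a portrait", "in the style of monet"], wlist = [0.25, 0.75]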
# load model(s)
pretrained_model_name_or_path = opt.model  # "./textual_inversion_set2"
use_auth_token = False

vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", use_auth_token=use_auth_token
)
vae.eval()
vae.cuda()

unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", use_auth_token=use_auth_token
)
unet.eval()
unet.cuda()

tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=use_auth_token)  # .cuda()
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=use_auth_token).cuda()

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
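# the LMS scheduler is constructed locally with Stable Diffusion v1's training noise schedule rather than loaded from the checkpoint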
slice_size = unet.config.attention_head_dim // opt.slices
unet.set_attention_slice(slice_size)

num_inference_steps = opt.steps  # 500
scheduler.set_timesteps(num_inference_steps)

# preprocess init image to latents if any
# otherwise initialize random latent
if opt.seed != 0:
    seed = opt.seed
else:
    seed = torch.Generator().seed()
generator = torch.Generator(device=device).manual_seed(seed)
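# when --seed is 0 a fresh random seed is drawn; it is not printed, so pass --seed explicitly if you need a reproducible run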
if opt.image != "":
    init_image = Image.open(opt.image).convert("RGB")
    init_image = init_image.resize((opt.w, opt.h))
    init_image = preprocess(init_image)
    # encode the init image into latents and scale the latents
    with torch.no_grad():
        init_latent_dist = vae.encode(init_image.to(device)).latent_dist
        init_latents = init_latent_dist.sample()
    init_latents = 0.18215 * init_latents
    init_latents = torch.cat([init_latents] * bs)
    # get the original timestep using init_timestep
    offset = scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * opt.s) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    timesteps = torch.tensor(
        [num_inference_steps - init_timestep] * bs, device=device,  # dtype=torch.long
    )
    # add noise to latents using the timesteps
    noise = torch.randn(init_latents.shape, device=device, generator=generator)
    latents = scheduler.add_noise(init_latents, noise, timesteps).to(device).detach()
else:
    latents = torch.randn(
        (bs, unet.in_channels, opt.h // 8, opt.w // 8), generator=generator, device=device
    )
    latents = latents * scheduler.sigmas[0]
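# at this point latents holds either the noised init-image latents (img2img) or pure noise scaled by the first sigma (txt2img)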
# discard encoder to save ram
del vae.encoder

# encode subprompts into a weighted embedding
text_input = tokenizer(plist, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

#pn = text_embeddings.shape[0]
tembs = torch.zeros_like(text_embeddings)[0].unsqueeze(0)
i = 0
for temb in text_embeddings:
    tembs = tembs + wlist[i] * temb
    i += 1
text_embeddings = tembs.detach()
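# text_embeddings is now a single embedding: the weighted sum of the subprompt embeddings (the weights sum to 1)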
# encode empty prompt
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * bs, padding="max_length", max_length=max_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
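# the unconditional and text embeddings are concatenated so both halves of classifier-free guidance run in one UNet batch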
eta = 0
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
    extra_step_kwargs["eta"] = eta

latents = latents.to(device)
j = 0
guidance_scale = opt.g

if opt.image != "":
    t_start = max(num_inference_steps - init_timestep + offset, 0)
else:
    t_start = 0
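# for img2img the first t_start steps are skipped: the init latents were already noised to the level where the loop starts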
with torch.no_grad():
    for i, t in tqdm(enumerate(scheduler.timesteps[t_start:])):
        t_index = t_start + i
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[t_index]
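        # k-diffusion style input scaling used by the LMS scheduler: divide by sqrt(sigma^2 + 1) before the UNet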
        latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
        latent_model_input = latent_model_input.to(unet.dtype)
        t = t.to(unet.dtype)
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.detach()
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample.detach()
        j += 1
del unet
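# undo the 0.18215 latent scaling before decoding with the VAE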
latents = 1 / 0.18215 * latents.detach()

with autocast("cuda"):
    image = vae.decode(latents.to(vae.dtype)).sample

print(latents.shape, image.shape)
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
image = numpy_to_pil(image)[0]
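# note: opt.dir is expected to exist already; create it beforehand (or add os.makedirs(opt.dir, exist_ok=True) above)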
image.save(opt.dir + os.sep + name + ".png")