Stable diffusion text2image assigning combined embeddings to each UNet block
# stable diffusion tool
# @htoyryla github twitter instagram
# requires diffusers 0.3.0 and a trained model
# relies heavily on code from https://github.com/huggingface/diffusers
# neurokuvatreenit, stable diffusion, example 1b using LDM scheduler and 1c saving image at each iteration
# 1d: multiple subprompts with weights
# 1e: estimate final result at each iteration
# 1e2: updated code for diffusers 0.4.0
'''
1mp3: experimental,
requires modified Unet from https://gist.github.com/htoyryla/dc8c12e3c2bc3543dc5679d56e30c532

For assigning weighted subprompts to blocks, use a prompt like

"something here first:10 / more stuff here:15 / still something important:35 ; 123 | something else first:10 / some more stuff here:15 ; 0456 "

where

|    embedding separator
/    subprompt separator
:70  relative weight of subprompt
;012 assign this embedding to blocks 0, 1 and 2
'''
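# Note (added for clarity): the digits after ";" index the seven per-block embedding
# slots (0-6) consumed by the modified UNet linked above; exactly how those indices map
# onto the down/mid/up blocks is defined in that gist, not here. Any slot that never
# receives an assignment keeps the empty-prompt embedding (see the placeholder setup
# further below), i.e. those blocks run effectively unconditioned.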
import torch
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from torch import autocast
import PIL
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import inspect
import sys

# we don't use the readymade pipelines so we need to import the modules for VAE, UNET, scheduler and CLIP
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
# parse input params
parser = argparse.ArgumentParser()
parser.add_argument('--text', type=str, default="", help='text prompt')
parser.add_argument('--model', type=str, default="./stable-diffusion-v1-4", help='path to sd model')
parser.add_argument('--steps', type=int, default=50, help='diffusion steps')
parser.add_argument('--g', type=float, default=7.5, help='guidance level')
parser.add_argument('--dir', type=str, default="out1", help='base directory for storing images')
parser.add_argument('--name', type=str, default="test", help='basename for storing images')
parser.add_argument('--imageSize', type=int, default=512, help='image size')
parser.add_argument('--h', type=int, default=0, help='image height')
parser.add_argument('--w', type=int, default=0, help='image width')
parser.add_argument('--slices', type=int, default=2, help='attention slices')
parser.add_argument('--seed', type=int, default=0, help='manual seed')
parser.add_argument('--saveiters', action="store_true", help='save intermediate images')
parser.add_argument('--log', type=str, default="sdruns.log", help='path to log file')

raw_args = " ".join(sys.argv)
opt = parser.parse_args()
# settings
num_inference_steps = opt.steps  #500
device = "cuda"

if opt.h == 0:
    opt.h = opt.imageSize
if opt.w == 0:
    opt.w = opt.imageSize

name = opt.name
steps = opt.steps
bs = 1
text = opt.text
guidance_scale = opt.g
def parse_inner(text):
    # prepare prompt: split into subprompts and their weights
    plist = []  # a list for subprompts
    wlist = []  # a list for their weights
    wsum = 0
    # separate assignments first
    parts = text.split(";")
    assigns = parts[1].strip()
    text = parts[0]
    parts = text.split("/")  # split into subprompts at each /
    print(parts)
    # separate text and weight for each subprompt
    for p in parts:
        ps = p.split(":")
        plist.append(ps[0].strip())
        w = float(ps[1])
        wlist.append(w)
        wsum += w
    # normalize weights
    for i in range(0, len(wlist)):
        wlist[i] = wlist[i] / wsum
    return plist, wlist, assigns

def parse_outer(text):
    tlist = text.split("|")
    return tlist
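# For illustration (not executed): with the example prompt from the docstring above,
# parse_outer() returns two strings, and parse_inner() on the first one gives
#   plist   = ["something here first", "more stuff here", "still something important"]
#   wlist   = [10/60, 15/60, 35/60]   (weights normalized to sum to 1)
#   assigns = "123"                   (combined embedding goes to blocks 1, 2 and 3)
# while the second one gives two subprompts with weights [0.4, 0.6] and assigns = "0456".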
# utility methods
def numpy_to_pil(images):
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

to_tensor_tfm = transforms.ToTensor()
def pil_to_latent(input_im):
    with torch.no_grad():
        latent = vae.encode(to_tensor_tfm(input_im).unsqueeze(0).to(device) * 2 - 1)  # note scaling
    return 0.18215 * latent.mode()  # or .mean or .sample
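# Note: pil_to_latent() is never called in this txt2img script (there is no init image),
# and the VAE encoder is deleted below to save memory, so it is kept only for reference.
# 0.18215 is the latent scaling factor used by Stable Diffusion's VAE, and the *2 - 1
# maps [0, 1] pixel values into the [-1, 1] range the VAE expects.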
# load model(s)
pretrained_model_name_or_path = opt.model  #"./textual_inversion_set2"
use_auth_token = True

# VAE, the image maker
print("loading VAE...")
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", use_auth_token=use_auth_token
)
vae.eval()
vae.cuda()
# the encoder is not needed for txt2img, drop it to save memory
del vae.encoder
# UNET, the denoiser
print("loading UNET...")
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", use_auth_token=use_auth_token
)
unet.eval()
unet.cuda()

# text encoder
print("loading CLIP...")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=use_auth_token)  #.cuda()
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=use_auth_token).cuda()
# Scheduler, the noise manager
print("setting up scheduler...")
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps)

eta = 0
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
    extra_step_kwargs["eta"] = eta
# attention slicing to save ram
slice_size = unet.config.attention_head_dim // opt.slices
unet.set_attention_slice(slice_size)
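# Attention slicing computes the attention for slice_size heads at a time instead of all
# heads at once, lowering peak VRAM at a small speed cost. For the SD v1 UNet config
# attention_head_dim is typically 8, so the default --slices 2 gives a slice size of 4.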
# set up random number gen
if opt.seed != 0:
    seed = opt.seed
else:
    seed = torch.Generator().seed()
generator = torch.Generator(device=device).manual_seed(seed)
print("Seed:" + str(generator.initial_seed()))
# make sure the output directory exists
os.makedirs(opt.dir, exist_ok=True)

# save command into log
if opt.log != "":
    if "--seed" not in raw_args:
        raw_args += " --seed " + str(seed)  # add current seed
    raw_args = "python " + raw_args + "\n"
    with open(os.path.join(opt.dir, opt.log), "a+") as text_file:
        text_file.write(raw_args)
# initialize latents randomly
latents = torch.randn(
    (bs, unet.in_channels, opt.h // 8, opt.w // 8), generator=generator, device=device
)
latents = latents * scheduler.init_noise_sigma
#latents = latents * scheduler.sigmas[0]
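# Multiplying by init_noise_sigma scales the unit-variance noise up to the scheduler's
# starting noise level (for the LMS sigmas this is essentially the largest sigma), which
# replaces the manual sigmas[0] scaling used with diffusers 0.3.x, kept above for reference.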
print("Latents shape",latents.shape) | |
latents = latents.to(device) | |
# text to tokens
# encode empty prompt
#tokens_length = text_tokens.input_ids.shape[-1]
tokens_length = tokenizer.model_max_length
uncond_tokens = tokenizer(
    [""] * bs, padding="max_length", max_length=tokens_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_tokens.input_ids.to(device))[0]
empty_emb = torch.cat([uncond_embeddings, uncond_embeddings])

# make placeholder for 7 embeddings and fill with uncond embeddings
prep_embeddings = empty_emb.unsqueeze(0).repeat(7, 1, 1, 1)
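# Each of the 7 slots holds a [uncond, cond] pair for classifier-free guidance, so
# prep_embeddings has shape [7, 2, 77, 768] for an SD v1 model (77 = tokenizer max
# length, 768 = CLIP text hidden size). Slots that never receive an assignment below
# keep the empty-prompt embedding, i.e. those UNet blocks stay unconditioned.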
# parse text into n prompt sets
tlist = parse_outer(text)
for t in tlist:
    plist, wlist, assign = parse_inner(t)
    # we process a list of all subprompts at the same time
    text_tokens = tokenizer(plist, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    print("Tokens shape", text_tokens.input_ids.shape)
    # tokens to embeddings
    with torch.no_grad():
        text_embeddings = text_encoder(text_tokens.input_ids.to(device))[0]
    print("Text embeddings shape ", text_embeddings.shape)
    # now we have the embeddings of all subprompts; if there is more than one, calculate their weighted average
    pn = text_embeddings.shape[0]
    if pn > 1:
        tembs = torch.zeros_like(text_embeddings)[0].unsqueeze(0)
        i = 0
        for temb in text_embeddings:
            tembs = tembs + wlist[i] * temb
            i += 1
        comb_embeddings = tembs.detach()
    else:
        # a single subprompt: use its embedding as-is
        comb_embeddings = text_embeddings.detach()
    print("Text embeddings shape after combining", comb_embeddings.shape)
    # assign to blocks
    for emb in comb_embeddings:
        temb = torch.cat([uncond_embeddings, emb.unsqueeze(0)])
        for c in assign:
            print(c)
            prep_embeddings[int(c)] = temb
print(prep_embeddings.shape)

# save some ram
del text_encoder
torch.cuda.empty_cache()
j = 0

# start diffusing
with torch.no_grad():
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        # prepare current latent for UNET
        latent_model_input = torch.cat([latents] * 2)
        # adjust latents according to sigmas (current noise level)
        # sigma no longer needed here in 0.4.0 but we use it later
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # estimate the noise
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=prep_embeddings).sample.detach()
        # adjust noise estimate for text guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
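        # Classifier-free guidance: eps = eps_uncond + g * (eps_text - eps_uncond), where
        # g is --g (7.5 by default); g = 1 would reproduce the purely text-conditioned
        # prediction, larger values push the result further from the unconditional one.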
        # estimate denoised latent
        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.detach()
        if opt.saveiters:
            # estimate final image from current state
            los = latents - sigma * noise_pred
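            # With this sigma parameterization a denoised estimate is roughly
            # x0 ≈ x - sigma * eps; since sigma here is taken from before the scheduler
            # step, this is only an approximate preview of the final latent, which is
            # then decoded with the usual 1/0.18215 VAE scaling below.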
            # save an image from current latents
            lats_ = 1 / 0.18215 * los.detach()
            image = vae.decode(lats_.to(vae.dtype)).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
            image = numpy_to_pil(image)[0]
            image.save(opt.dir + os.sep + name + "-t" + str(j) + ".png")
        j += 1

del unet
# now we have the final latent, let's decode the image and save it
latents = 1 / 0.18215 * latents.detach()
image = vae.decode(latents.to(vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
image = numpy_to_pil(image)[0]
image.save(opt.dir + os.sep + name + ".png")
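# Example invocation (illustrative only; the script filename, model path and prompt are
# placeholders, not from the original gist):
#
#   python sd_blocks_1mp3.py --model ./stable-diffusion-v1-4 --steps 50 --g 7.5 \
#       --dir out1 --name blocks-test --saveiters \
#       --text "a stone castle on a hill:60 / thick morning fog:40 ; 0123 | oil painting, heavy impasto:100 ; 456"
#
# Here the castle/fog mixture conditions blocks 0-3 while the painterly style prompt
# conditions blocks 4-6 of the modified UNet.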