Stable diffusion text2image assigning different prompts to each UNet block EXPERIMENTAL
# stable diffusion tool
# @htoyryla github twitter instagram
# requires diffusers 0.4.0 and a trained model
# relies heavily on code from https://github.com/huggingface/diffusers
# neurokuvatreenit, stable diffusion, example 1b using LDM scheduler and 1c saving image at each iteration
# 1d: multiple subprompts with weights
# 1e: estimate final result at each iteration
# 1e2: updated code for diffusers 0.4.0
'''
1mp2: experimental,
requires modified Unet from https://gist.github.com/htoyryla/dc8c12e3c2bc3543dc5679d56e30c532
assign subprompts to different layers using
--text "something here first:012 / more stuff here:3 / still something important:456"
'''
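# Note (added for clarity, based on how prep_embeddings is built below): the digits after each
# colon are indices 0-6 into a tensor of seven per-block embeddings, presumably one per
# cross-attention stage of the modified Unet (down blocks, mid block, up blocks); the exact
# index-to-block mapping is defined by the modified Unet linked above, not here. A block whose
# index is not listed for any subprompt falls back to the empty (unconditional) prompt.
# Hypothetical example:
#   --text "a misty forest:0123 / an old castle:456"
# sends "a misty forest" to blocks 0-3 and "an old castle" to blocks 4-6.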
import torch
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from torch import autocast
import PIL
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import inspect
import sys
# we don't use the readymade pipelines so we need to import the modules for VAE, UNET, scheduler and CLIP
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
# parse input params
parser = argparse.ArgumentParser()
parser.add_argument('--text', type=str, default="", help='text prompt')
parser.add_argument('--model', type=str, default="./stable-diffusion-v1-4", help='path to sd model')
parser.add_argument('--steps', type=int, default=50, help='diffusion steps')
parser.add_argument('--g', type=float, default=7.5, help='guidance level')
parser.add_argument('--dir', type=str, default="out1", help='base directory for storing images')
parser.add_argument('--name', type=str, default="test", help='basename for storing images')
parser.add_argument('--imageSize', type=int, default=512, help='image size')
parser.add_argument('--h', type=int, default=0, help='image height')
parser.add_argument('--w', type=int, default=0, help='image width')
parser.add_argument('--slices', type=int, default=2, help='attention slices')
parser.add_argument('--seed', type=int, default=0, help='manual seed')
parser.add_argument('--saveiters', action="store_true", help='save intermediate images')
parser.add_argument('--log', type=str, default="sdruns.log", help='path to log file')
raw_args = " ".join(sys.argv)
opt = parser.parse_args()
# settings
num_inference_steps = opt.steps #500
device = "cuda"
if opt.h == 0:
    opt.h = opt.imageSize
if opt.w == 0:
    opt.w = opt.imageSize
name = opt.name
steps = opt.steps
bs = 1
text = opt.text
guidance_scale = opt.g
# make sure the output directory exists (opt.dir is used for the log and the saved images)
os.makedirs(opt.dir, exist_ok=True)
# prepare prompt: split into subprompts and their assignments
plist = [] # a list for subprompts
wlist = [] # a list for their layers
parts = text.split("/") # split into subprompts at each /
print(parts)
# separate text and assignment for each subprompt
for p in parts:
    ps = p.split(":")
    plist.append(ps[0].strip())
    wlist.append(ps[1].strip())
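# For example (hypothetical prompt), --text "a misty forest:0123 / an old castle:456"
# gives plist = ["a misty forest", "an old castle"] and wlist = ["0123", "456"];
# each subprompt must therefore include a ":" followed by its block indices.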
# utility methods
def numpy_to_pil(images):
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
to_tensor_tfm = transforms.ToTensor()
def pil_to_latent(input_im):
    with torch.no_grad():
        latent = vae.encode(to_tensor_tfm(input_im).unsqueeze(0).to(device)*2-1) # note scaling to [-1, 1]
    return 0.18215 * latent.mode() # or .mean or .sample
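# note: pil_to_latent is not called in this text-to-image script; it would also need the VAE
# encoder, which is deleted below to save memory, so treat it as a leftover from the img2img variants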
# load model(s)
pretrained_model_name_or_path = opt.model #"./textual_inversion_set2"
use_auth_token = True
# VAE, the image maker
print("loading VAE...")
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", use_auth_token=use_auth_token
)
vae.eval()
vae.cuda()
del vae.encoder
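# only the VAE decoder is needed for text-to-image, so dropping the encoder frees some VRAM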
# UNET, the denoiser
print("loading UNET...")
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", use_auth_token=use_auth_token
)
unet.eval()
unet.cuda()
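# note: as stated in the header, this assumes the stock diffusers UNet2DConditionModel has been
# replaced by the modified version from the gist linked above, so that encoder_hidden_states
# can carry one (uncond, cond) embedding pair per block instead of a single embedding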
# text encoder
print("loading CLIP...")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=use_auth_token) #.cuda()
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=use_auth_token).cuda()
# Scheduler, the noise manager
print("setting up scheduler...")
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps)
eta = 0
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
    extra_step_kwargs["eta"] = eta
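# eta only applies to DDIM-style schedulers; LMSDiscreteScheduler.step does not accept it,
# so extra_step_kwargs normally stays empty here and the check just keeps the code scheduler-agnostic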
# attention slicing to save ram
slice_size = unet.config.attention_head_dim // opt.slices
unet.set_attention_slice(slice_size)
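# computing attention in chunks of slice_size heads trades a little speed for lower peak memory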
# set up random number gen
if opt.seed != 0:
    seed = opt.seed
else:
    seed = torch.Generator().seed()
generator = torch.Generator(device=device).manual_seed(seed)
print("Seed:" + str(generator.initial_seed()))
# save command into log
if opt.log != "":
    if "--seed" not in raw_args:
        raw_args += " --seed " + str(seed) # add current seed
    raw_args = "python " + raw_args + "\n"
    with open(os.path.join(opt.dir, opt.log), "a+") as text_file:
        text_file.write(raw_args)
# initialize latents randomly
latents = torch.randn(
    (bs, unet.in_channels, opt.h // 8, opt.w // 8), generator=generator, device=device
)
latents = latents * scheduler.init_noise_sigma
#latents = latents * scheduler.sigmas[0]
print("Latents shape", latents.shape)
latents = latents.to(device)
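# the starting point is pure Gaussian noise in latent space, shape (1, 4, h/8, w/8) for the
# SD v1.x VAE/UNet, scaled by the scheduler's initial sigma (for LMS this equals sigmas[0],
# which is what the commented-out pre-0.4.0 line did)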
# text to tokens
# we process a list of all subprompts at the same time
text_tokens = tokenizer(plist, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
print("Tokens shape", text_tokens.input_ids.shape)
# tokens to embedding
with torch.no_grad():
    text_embeddings = text_encoder(text_tokens.input_ids.to(device))[0]
print("Text embeddings shape ", text_embeddings.shape)
# encode empty prompt
tokens_length = text_tokens.input_ids.shape[-1]
uncond_tokens = tokenizer(
    [""] * bs, padding="max_length", max_length=tokens_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_tokens.input_ids.to(device))[0]
# now we process the subprompts into a tensor containing an embedding for each block in the Unet
tembs = zip(text_embeddings, wlist)
empty_emb = torch.cat([uncond_embeddings, uncond_embeddings])
prep_embeddings = empty_emb.unsqueeze(0).repeat(7, 1, 1, 1)
# loop through subprompts
for emb in tembs:
    temb = torch.cat([uncond_embeddings, emb[0].unsqueeze(0)])
    for c in emb[1]:
        prep_embeddings[int(c)] = temb
print(prep_embeddings.shape)
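# prep_embeddings ends up with shape (7, 2, 77, 768): index 0-6 selects a block of the modified
# UNet, and each entry holds an (unconditional, conditional) pair for classifier-free guidance.
# Blocks start out with the empty prompt in both slots and are overwritten for every digit that
# appears in a subprompt's assignment string; which index maps to which down/mid/up block is
# defined by the modified Unet gist, not here.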
# save some ram
del text_encoder
torch.cuda.empty_cache()
j = 0
# start diffusing
with torch.no_grad():
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        # prepare current latent for UNET
        latent_model_input = torch.cat([latents] * 2)
        # adjust latents according to sigmas (current noise level)
        # sigma no longer needed here in 0.4.0 but we use it later
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # estimate the noise
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=prep_embeddings).sample.detach()
        # adjust noise estimate for text guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
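        # classifier-free guidance: the batch of 2 holds the unconditional and the text-conditioned
        # prediction; we extrapolate from the former toward the latter, with --g (guidance_scale)
        # controlling how strongly the per-block prompts steer the result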
        # estimate denoised latent
        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.detach()
        if opt.saveiters:
            # estimate final image from current state
            los = latents - sigma * noise_pred
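            # rough preview of the fully denoised latent: with epsilon prediction and the LMS sigma
            # parameterization, x0 is approximately x_t - sigma * predicted_noise; here it is applied
            # to the already-updated latents, so treat it as a quick visual preview, not an exact x0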
            # save an image from current latents
            lats_ = 1 / 0.18215 * los.detach()
            image = vae.decode(lats_.to(vae.dtype)).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
            image = numpy_to_pil(image)[0]
            image.save(opt.dir + os.sep + name + "-t" + str(j) + ".png")
        j += 1
del unet
# now we have the final latent, let's decode the image and save it
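# 0.18215 is the fixed latent scaling factor used when training SD v1.x, so latents are divided
# by it before going through the VAE decoder, and the decoder's [-1, 1] output is mapped to [0, 1]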
latents = 1 / 0.18215 * latents.detach()
image = vae.decode(latents.to(vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
image = numpy_to_pil(image)[0]
image.save(opt.dir + os.sep + name + ".png")