Stable diffusion text2image assigning different prompts to each UNet block EXPERIMENTAL
# stable diffusion tool
# @htoyryla github twitter instagram
# requires diffusers 0.4.0 and a trained model
# relies heavily on code from https://github.com/huggingface/diffusers
# neurokuvatreenit, stable diffusion, example 1b using LDM scheduler and 1c saving image at each iteration
# 1d: multiple subprompts with weights
# 1e: estimate final result at each iteration
# 1e2: updated code for diffusers 0.4.0
'''
1mp2: experimental,
requires modified Unet from https://gist.github.com/htoyryla/dc8c12e3c2bc3543dc5679d56e30c532
assign subprompts to different layers using
--text "something here first:012 / more stuff here:3 / still something important:456"
'''
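# hypothetical example run (the script name, model path and prompts below are placeholders):
#
#   python sd_blockprompts.py --model ./stable-diffusion-v1-4 --steps 50 --g 7.5 \
#       --dir out1 --name test --text "a medieval castle:0123 / dramatic sunset sky:456"
#
# here the first subprompt conditions block slots 0-3 of the modified UNet and the second
# conditions slots 4-6; which physical down/mid/up block each of the seven slots maps to
# is defined by the modified Unet linked above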
import torch
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from torch import autocast
import PIL
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import inspect
import sys
# we don't use the readymade pipelines so we need to import the modules for VAE, UNET, scheduler and CLIP
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
# parse input params
parser = argparse.ArgumentParser()
parser.add_argument('--text', type=str, default="", help='text prompt')
parser.add_argument('--model', type=str, default="./stable-diffusion-v1-4", help='path to sd model')
parser.add_argument('--steps', type=int, default=50, help='diffusion steps')
parser.add_argument('--g', type=float, default=7.5, help='guidance level')
parser.add_argument('--dir', type=str, default="out1", help='base directory for storing images')
parser.add_argument('--name', type=str, default="test", help='basename for storing images')
parser.add_argument('--imageSize', type=int, default=512, help='image size')
parser.add_argument('--h', type=int, default=0, help='image height')
parser.add_argument('--w', type=int, default=0, help='image width')
parser.add_argument('--slices', type=int, default=2, help='attention slices')
parser.add_argument('--seed', type=int, default=0, help='manual seed')
parser.add_argument('--saveiters', action="store_true", help='save intermediate images')
parser.add_argument('--log', type=str, default="sdruns.log", help='path to log file')
raw_args = " ".join(sys.argv)
opt = parser.parse_args()
# settings
num_inference_steps = opt.steps #500
device = "cuda"
if opt.h == 0:
    opt.h = opt.imageSize
if opt.w == 0:
    opt.w = opt.imageSize
name = opt.name
steps = opt.steps
bs = 1
text = opt.text
guidance_scale = opt.g
# prepare prompt: split into subprompts and their assignments
plist = [] # a list for subprompts
wlist = [] # a list for their layers
parts = text.split("/") # split into subprompts at each /
print(parts)
# separate text and assignment for each subprompt
for p in parts:
    ps = p.split(":")
    plist.append(ps[0].strip())
    wlist.append(ps[1].strip())
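# e.g. --text "a medieval castle:0123 / dramatic sunset sky:456" gives
#   plist = ["a medieval castle", "dramatic sunset sky"]
#   wlist = ["0123", "456"]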
# utility methods
def numpy_to_pil(images):
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
to_tensor_tfm = transforms.ToTensor()
def pil_to_latent(input_im):
    # unused in this text2image script (the VAE encoder is deleted below); kept from earlier versions
    with torch.no_grad():
        latent = vae.encode(to_tensor_tfm(input_im).unsqueeze(0).to(device)*2-1) # Note scaling
    return 0.18215 * latent.latent_dist.mode() # or .mean or .sample
# load model(s)
pretrained_model_name_or_path = opt.model #"./textual_inversion_set2"
use_auth_token = True
# VAE the imagemaker
print("loadng VAE...")
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path, subfolder="vae", use_auth_token=use_auth_token
)
vae.eval()
vae.cuda()
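# only the decoder is needed for text2image (we never encode an input image),
# so drop the encoder to free memory; note that pil_to_latent above would need it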
del vae.encoder
#UNET the denoiser
print("loadng UNET...")
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path, subfolder="unet", use_auth_token=use_auth_token
)
unet.eval()
unet.cuda()
# text encoder
print("loadng CLIP...")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=use_auth_token) #.cuda()
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=use_auth_token).cuda()
# Scheduler the noise manager
print("setting up scheduler...")
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps)
eta = 0
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
    extra_step_kwargs["eta"] = eta
# attention slicing to save ram
slice_size = unet.config.attention_head_dim // opt.slices
unet.set_attention_slice(slice_size)
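# computing attention in chunks of slice_size heads lowers peak VRAM usage at a small speed cost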
# set up random number gen
if opt.seed != 0:
    seed = opt.seed
else:
    seed = torch.Generator().seed()
generator = torch.Generator(device=device).manual_seed(seed)
print("Seed:" + str(generator.initial_seed()))
# make sure the output directory exists, then save the command into the log
os.makedirs(opt.dir, exist_ok=True)
if opt.log != "":
    if "--seed" not in raw_args:
        raw_args += " --seed "+str(seed) # add current seed
    raw_args = "python "+raw_args+"\n"
    with open(os.path.join(opt.dir, opt.log), "a+") as text_file:
        text_file.write(raw_args)
# initialize latents randomly
latents = torch.randn(
    (bs, unet.in_channels, opt.h // 8, opt.w // 8), generator = generator, device = device
)
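# scale the initial noise to the scheduler's starting noise level (LMS works in sigma space)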
latents = latents * scheduler.init_noise_sigma
#latents = latents * scheduler.sigmas[0]
print("Latents shape",latents.shape)
latents = latents.to(device)
# text to tokens
# we process a list of all subprompts at the same time
text_tokens = tokenizer(plist, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
print("Tokens shape",text_tokens.input_ids.shape)
# tokens to embedding
with torch.no_grad():
    text_embeddings = text_encoder(text_tokens.input_ids.to(device))[0]
print("Text embeddings shape ",text_embeddings.shape)
# encode empty prompt
tokens_length = text_tokens.input_ids.shape[-1]
uncond_tokens = tokenizer(
    [""] * bs, padding="max_length", max_length=tokens_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_tokens.input_ids.to(device))[0]
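# the empty-prompt ("unconditional") embedding is the reference that classifier-free guidance steers away from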
# now we process the subprompts into a tensor containing an embedding for each block in the Unet
tembs = zip(text_embeddings, wlist)
empty_emb = torch.cat([uncond_embeddings, uncond_embeddings])
prep_embeddings = empty_emb.unsqueeze(0).repeat(7,1,1,1)
#loop through subprompts
for emb in tembs:
    temb = torch.cat([uncond_embeddings, emb[0].unsqueeze(0)])
    for c in emb[1]:
        prep_embeddings[int(c)] = temb
print(prep_embeddings.shape)
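# prep_embeddings now has shape (7, 2, tokens, dim): one (uncond, cond) embedding pair per UNet block slot
# (77 tokens x 768 dims with the standard SD v1 text encoder); slots that were not assigned a
# subprompt keep the empty prompt on both halves, i.e. they receive no text guidance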
# save some ram
del text_encoder
torch.cuda.empty_cache()
j = 0
# start diffusing
with torch.no_grad():
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        # prepare current latent for UNET
        latent_model_input = torch.cat([latents] * 2)
        # adjust latents according to sigmas (current noise level)
        # sigma no longer needed here in 0.4.0 but we use it later
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # estimate the noise
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=prep_embeddings).sample.detach()
        # adjust noise estimate for text guidance
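        # classifier-free guidance: the doubled batch produced an unconditional and a
        # text-conditioned prediction; push the estimate from the former towards the latter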
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # estimate denoised latent
        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.detach()
        if opt.saveiters:
            # estimate final image from current state
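            # with LMS a noisy latent is roughly x0 + sigma * eps, so subtracting sigma * noise_pred
            # gives a rough preview of the final latent (latents have already been stepped once
            # more at this point, so this is only an approximation)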
            los = latents - sigma * noise_pred
            # save an image from current latents
            lats_ = 1 / 0.18215 * los.detach()
            image = vae.decode(lats_.to(vae.dtype)).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
            image = numpy_to_pil(image)[0]
            image.save(opt.dir+os.sep+name +"-t"+str(j)+".png")
        j += 1
del unet
# now we have final latent, let's decode the image and save it
latents = 1 / 0.18215 * latents.detach()
image = vae.decode(latents.to(vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
image = numpy_to_pil(image)[0]
image.save(opt.dir+os.sep+name +".png")