My Stable Diffusion image generation function. Abbreviated – will not run because of a few missing utility functions and classes.
# Imports inferred from the code below. The utility helpers referenced further
# down (do_seed, pil_img_to_latent, preprocess_mask, torch_img_to_pil,
# make_pil_grid, load_pil_img, collect_and_empty, sample_euler_inpaint and the
# custom exception classes) are among the pieces not included in this gist.
import datetime
import os
import time
from contextlib import nullcontext

import accelerate
import torch
import torch.nn as nn
from torch import autocast
from tqdm.auto import trange

import k_diffusion as K


def normalize_latent(x, max_val, quantile_val):
    # Per-sample latent rescaling: bring each sample down to unit standard
    # deviation if its std exceeds 1.0, then scale it so that its
    # `quantile_val` absolute-value quantile is at most `max_val`.
    x = x.detach().clone()
    for i in range(x.shape[0]):
        if x[[i], :].std() > 1.0:
            x[[i], :] = x[[i], :] / x[[i], :].std()
        s = torch.quantile(torch.abs(x[[i], :]), quantile_val)
        s = torch.maximum(s, torch.ones_like(s) * max_val)
        x[[i], :] = x[[i], :] / (s / max_val)
    return x
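
# Illustration of normalize_latent (not executed): for a sample with std 3.0,
# the division by std brings it to unit std; the 0.975 quantile of |N(0, 1)|
# is ~2.24, so with max_val=2.0 the final division moves that quantile down to
# exactly 2.0. Samples already within range are left (almost) untouched.
#
#   x = torch.randn(2, 4, 64, 64) * 3.0
#   y = normalize_latent(x, max_val=2.0, quantile_val=0.975)
#   torch.quantile(y[0].abs(), 0.975)  # ~2.0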

class CFGDenoiser(nn.Module):
    # Wraps a k-diffusion denoiser and applies classifier-free guidance,
    # optionally rescaling the combined prediction at every step.
    def __init__(self, model, steps, rescale, rescaling_coeff):
        super().__init__()
        self.inner_model = model
        self.total_steps = steps
        self.current_step = 0
        self.rescale = rescale
        self.rescaling_coeff = rescaling_coeff

    def forward(self, x, sigma, uncond, cond, cond_scale):
        self.current_step += 1
        # Batch the unconditional and conditional passes into one forward call
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        x_out = uncond_out + (cond_out - uncond_out) * cond_scale
        if self.rescale:
            x_out = normalize_latent(x_out, self.rescaling_coeff, 0.975)
        return x_out
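
# Note: the combination above is standard classifier-free guidance,
#   x_out = uncond + cond_scale * (cond - uncond)
# cond_scale = 1 reproduces the conditional prediction; cond_scale > 1
# extrapolates away from the unconditional one, strengthening prompt adherence
# at the cost of possible over-saturation (hence the optional
# normalize_latent rescaling).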

K_SAMPLERS = {
    'k_lms': K.sampling.sample_lms,
    'k_dpm_2': K.sampling.sample_dpm_2,
    'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
    'k_heun': K.sampling.sample_heun,
    'k_euler': K.sampling.sample_euler,
    # Deferred lookup: sample_euler_inpaint is one of the utility functions
    # not included in this gist
    'k_euler_inpaint': lambda *args, **kwargs: sample_euler_inpaint(*args, **kwargs),
    'k_euler_ancestral': K.sampling.sample_euler_ancestral,
}
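
# The inpainting sampler is one of the utilities missing from this gist. The
# sketch below is a guess at one plausible implementation, not the author's
# code: a plain Euler sampler that, after each step, re-noises the original
# latent to the current noise level and pastes it back where the mask is 1
# (assuming mask == 1 marks the region to keep from the original image).
@torch.no_grad()
def _sample_euler_inpaint_sketch(model, x, sigmas, extra_args=None,
                                 callback=None, disable=None,
                                 mask=None, latent_img_orig=None):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = (x - denoised) / sigmas[i]             # dx/dsigma of the ODE
        x = x + d * (sigmas[i + 1] - sigmas[i])    # Euler step
        if mask is not None:
            # Re-noise the original latent to the next noise level (the added
            # noise vanishes at the final step, where sigma is zero)
            noised_orig = latent_img_orig + torch.randn_like(x) * sigmas[i + 1]
            x = noised_orig * mask + x * (1.0 - mask)
    return x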

def _generate(model, settings, init_img, prompt, guidance, mask_img, verbose, silent):
    opt = settings
    seed = do_seed(opt.seed, print_seed=not silent)
    half_precision = bool(model.half_precision)
    batch_size = opt.n_samples
    n_cols = opt.n_cols if opt.n_cols > 0 else batch_size

    if opt.sampler is None:
        opt.sampler = 'k_euler_inpaint' if mask_img is not None else 'k_euler'

    accelerator = accelerate.Accelerator()
    device = accelerator.device

    # Instantiate the k-diffusion denoiser
    denoiser = K.external.CompVisDenoiser(model)

    # Convert the prompt to the long format: a list of (text, weight) pairs
    data = []
    if isinstance(prompt, str):
        data.append((prompt, 1.0))
    elif prompt is not None:
        data = [*prompt]

    # For the output filename
    prompt_text_concat = ', '.join([p for p, _ in data]) if len(data) > 0 else 'No prompt'

    # Latent shape (not currently in use for img2img)
    shape = [opt.C, opt.height // opt.f, opt.width // opt.f]

    # In contrast to the official SD script, 'strength' modifies the denoising
    # process for txt2img as well as for img2img
    t_enc = int(opt.strength * opt.ddim_steps)

    if init_img is None:
        # txt2img: start from an all-zeros latent
        init_latent = torch.zeros([batch_size, *shape], device=device)
    else:
        # img2img: encode the initial image into latent space
        init_latent = pil_img_to_latent(init_img, batch_size=batch_size, device=device, half=half_precision)

    latent_mask = None
    if mask_img is not None and init_img is not None:
        # Inpainting
        latent_mask = preprocess_mask(mask_img, inverse=opt.invert_mask).to(device)
        latent_mask = torch.cat([latent_mask] * batch_size)

    # k-diffusion sampler sigmas (continuous noise levels)
    if opt.schedule == 'default':
        sigmas = denoiser.get_sigmas(opt.ddim_steps)
    elif opt.schedule == 'karras':
        model_sigmas = (((1 - model.alphas_cumprod) / model.alphas_cumprod) ** 0.5).to('cpu')
        sigmas = K.sampling.get_sigmas_karras(opt.ddim_steps, model_sigmas[0], model_sigmas[-1], rho=7.0, device='cpu').cuda()
    else:
        raise InvalidSettingsError('Value of setting "schedule" must be either "default" or "karras"')
    sigmas_actual = sigmas[-(t_enc + 1):]

    # The initial noise is multiplied by sigmas_actual[0] by default, but more
    # or less noise than the default can be added via the 'sigma_offset_steps'
    # and 'extra_noise' settings.
    noising_sigma = sigmas_actual[-min(len(sigmas_actual), t_enc + 1 - opt.sigma_offset_steps)] + opt.extra_noise
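
    # Worked example of the arithmetic above (hypothetical values): with
    # ddim_steps=50 and strength=0.6, t_enc = int(0.6 * 50) = 30 and
    # sigmas_actual = sigmas[-31:] keeps the last 31 of the 51 returned sigma
    # values (k-diffusion schedules append a trailing zero), so the sampler
    # runs only the final 30 denoising steps. With sigma_offset_steps=0 and
    # extra_noise=0.0, noising_sigma is simply sigmas_actual[0].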

    images = []
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope('cuda'):
            with model.ema_scope():
                tic = time.time()
                for n in range(opt.n_iter):
                    if not silent:
                        print(f'Iteration {n + 1} of {opt.n_iter}')
                    do_seed(seed if n == 0 else None, print_seed=n > 0)

                    # Unconditional guidance
                    uc = model.get_learned_conditioning(batch_size * [""])

                    # Conditional guidance
                    if guidance is None:
                        # Create the guidance tensor from the supplied,
                        # possibly weighted, prompt(s)
                        c = torch.zeros_like(uc)
                        weight_sum = sum([weight for _, weight in data])
                        for p, weight in data:
                            c += (model.get_learned_conditioning(batch_size * [p]) * weight) / (weight_sum if opt.normalize_prompt_weights else 1)
                    else:
                        # Use the pre-computed guidance tensor passed by the caller
                        c = torch.cat(batch_size * [guidance])

                    # Add noise to the initial latent (in the case of txt2img,
                    # the initial latent is all-zeros before this)
                    if opt.rescale:
                        init_latent = normalize_latent(init_latent, opt.latent_rescaling_coeff, 0.975)
                    x = init_latent + torch.randn_like(init_latent) * noising_sigma

                    denoiser_cfg = CFGDenoiser(denoiser, t_enc,
                                               rescale=opt.rescale, rescaling_coeff=opt.latent_rescaling_coeff)
                    extra_args = {
                        'cond_scale': opt.scale,
                        'cond': c,
                        'uncond': uc
                    }
                    extra_sampler_args = {} if mask_img is None else {
                        'mask': latent_mask,
                        'latent_img_orig': init_latent
                    }
                    samples = K_SAMPLERS[opt.sampler](
                        denoiser_cfg, x, sigmas_actual,
                        extra_args=extra_args,
                        disable=not accelerator.is_main_process,
                        **extra_sampler_args)

                    if accelerator.is_main_process:
                        for i in range(batch_size):
                            decoded_img = model.decode_first_stage(samples[i].unsqueeze(0))
                            images.append(torch_img_to_pil(decoded_img, increase_contrast=opt.rescale))
                        del samples
                        del decoded_img

    if accelerator.is_main_process and opt.grid:
        images = [make_pil_grid(images, n_cols)] + images

    if not silent:
        toc = time.time()
        time_diff = toc - tic
        print(f'Image generation took {time_diff} seconds.')

    if opt.output_folder is not None:
        current_timestamp = datetime.datetime.now().isoformat(timespec='seconds')
        generation_folder = os.path.join(opt.output_folder, f'{current_timestamp}__{prompt_text_concat[:25]}')
        os.makedirs(generation_folder, exist_ok=True)
        for i, img in enumerate(images):
            if i == 0 and opt.grid:
                filename = os.path.join(generation_folder, 'grid.png')
            else:
                filename = os.path.join(generation_folder, f'image{i}.png')
            img.save(filename)

    return images

def generate(model, settings, init_img=None, prompt=None, guidance=None, mask_img=None, verbose=False, silent=False):
    if isinstance(init_img, str):
        init_img = load_pil_img(init_img)
    if isinstance(mask_img, str):
        mask_img = load_pil_img(mask_img, nn_resampling=True)
    if prompt is None and guidance is None:
        raise InvalidArgumentsError('One of "prompt" and "guidance" must be set')
    if prompt is not None and guidance is not None:
        raise InvalidArgumentsError('Only one of "prompt" and "guidance" can be set')
    try:
        return _generate(model, settings, init_img, prompt, guidance, mask_img, verbose, silent)
    finally:
        # Free GPU memory regardless of whether generation succeeded
        collect_and_empty()
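
# A minimal usage sketch (hypothetical, shown as a comment: it assumes a loaded
# CompVis LatentDiffusion model in `model` and that the missing utility
# functions exist; the settings object only needs the fields read above).
#
#   from types import SimpleNamespace
#   settings = SimpleNamespace(
#       seed=42, n_samples=2, n_cols=0, n_iter=1, sampler=None,
#       C=4, f=8, height=512, width=512,
#       ddim_steps=50, strength=1.0, schedule='karras',
#       sigma_offset_steps=0, extra_noise=0.0,
#       scale=7.5, rescale=False, latent_rescaling_coeff=2.0,
#       normalize_prompt_weights=True, invert_mask=False,
#       precision='autocast', grid=False, output_folder='outputs')
#   images = generate(model, settings, prompt='a watercolor painting of a fox')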