Created
July 3, 2023 17:55
-
-
Save Clybius/9a4221245d27898f9c6e8c7cf4fc7eff to your computer and use it in GitHub Desktop.
DPM++ 2M SDE Adaptive Sampler | A modified 2M SDE sampler complete with cosine similarity matching with a 2S step, and adaptive second order sampling.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# How to set up in A1111 Stable Diffusion and its various forks
1) Install the sampler by appending it to the bottom of `repositories/k-diffusion/k_diffusion/sampling.py`.
2) Then, in `modules/sd_samplers_kdiffusion.py`, add the following entry to `samplers_k_diffusion`:
('DPM++ 2M SDE Adaptive', 'sample_dpmpp_2m_sde_adaptive', ['c_dpmpp_2m_sde_ad'], {"brownian_noise": True, 'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
Notes: This has been tested on Vlad's A1111 fork. Step 2 may not be exactly the same on native A1111, but you should be able to work around any errors using contextual clues.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@torch.no_grad()
def sample_dpmpp_2m_sde_adaptive(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM++ 2M SDE steps blended with an adaptive DPM++ 2S ancestral step.

    Each iteration computes a DPM++ 2M SDE candidate. From the second step
    onward, that candidate is mixed with a 2S-style ancestral update using a
    weight derived from the cosine similarity between the current latent and
    the candidate. A PID step-size controller then decides, from the distance
    between the latent and the blended candidate, whether to spend a second
    model evaluation on a DPM++ 2S midpoint correction.

    Args:
        model: Denoiser, called as ``model(x, sigma_batch, **extra_args)``.
        x: Initial noisy latent. The similarity weighting indexes a 3-D map
            with ``[0, :, :]``, so a 4-D ``(batch, channel, h, w)`` tensor is
            assumed -- TODO confirm against callers.
        sigmas: Noise schedule, indexed by step; a final value of 0 triggers
            a plain denoising (Euler) step.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step.
        disable: Forwarded to ``trange`` to disable the progress bar.
        eta: Stochasticity scale, used both for the ancestral step split and
            for the SDE terms.
        s_noise: Multiplier applied to all injected noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``. Defaults to
            a ``BrownianTreeNoiseSampler`` spanning the positive sigma range.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    pid = PIDStepSizeController(0.05, 0., 1., 0., 1.5, 0.71)
    cos_sim = torch.nn.CosineSimilarity()

    # Loop-invariant values, built once instead of on every step.
    # timestep_fac falls from 1 at the first step to 0 at the last step; it
    # cross-fades the 2M SDE noise (weight timestep_fac) against the
    # ancestral noise (weight 1 - timestep_fac) on accepted 2S steps.
    timestep = torch.linspace(0, math.pi, torch.numel(sigmas), device=x.device)
    timestep_fac = torch.cos(timestep) * 1 / 2 + 1 / 2
    # Absolute floor of the per-element error scale.
    error_floor = torch.tensor(0.0078)

    old_denoised = None
    x_prev = x
    h_last = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Step size in t = -log(sigma).
        t, s = -sigmas[i].log(), -sigmas[i + 1].log()
        h = s - t
        if sigmas[i + 1] == 0:
            # Denoising step
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            eta_h = eta * h
            # First-order 2M SDE candidate (deterministic part only).
            x_check = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised  # 2M SDE
            if old_denoised is not None:
                r = h_last / h
                # Compute cosine similarity between the latent and the candidate.
                # NOTE(review): normalisation runs along dim 0 and only the
                # first entry of the similarity map is kept below, so for
                # batch sizes > 1 every sample shares sample 0's weights --
                # confirm whether this is intended.
                x_norm = torch.nn.functional.normalize(x, p=2, dim=0)
                x_check_norm = torch.nn.functional.normalize(x_check, p=2, dim=0)
                simab = cos_sim(x_norm, x_check_norm)
                dot_product = torch.dot(x_norm.view(-1), x_check_norm.view(-1))
                magnitude_similarity = dot_product / (torch.norm(x_norm) * torch.norm(x_check_norm))
                combined_similarity = (simab + magnitude_similarity) / 2.0
                # NOTE(review): the similarity is rescaled with the value
                # range of x_check, not of the similarity map itself.
                k = (combined_similarity - x_check.min()) / (x_check.max() - x_check.min())
                k = k[0, :, :].clip(min=0, max=1.0)
                # (1 - k) weights the second-order 2M SDE update ...
                x_check_pre = (1 - k) * x_check + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) * (1 - k)  # 2M SDE
                # DPM++ 2S vars
                t_2s, t_next_2s = t_fn(sigmas[i]), t_fn(sigma_down)
                h_2s = t_next_2s - t_2s
                s_2s = t_2s + 0.5 * h_2s  # midpoint in t between sigma and sigma_down
                # ... and k weights the first-order ancestral update.
                x_check_pre2 = k * (sigma_fn(t_next_2s) / sigma_fn(t_2s)) * x - (t_2s - t_next_2s).expm1() * denoised * k
                x_check = x_check_pre + x_check_pre2
                # Scaled RMS distance between the latent and the blended
                # candidate, fed to the PID controller.
                delta = torch.maximum(error_floor, 0.05 * torch.maximum(x.abs(), x_prev.abs()))  # Adaptive noise addition and denoise
                error = torch.linalg.norm((x - x_check) / delta) / x.numel() ** 0.5
                accept = pid.propose_step(error)
                if accept:
                    x_prev = x
                    # Noise the candidate, then evaluate the model on it at
                    # the midpoint sigma.
                    x_check = x_check + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up * (1. - timestep_fac[i])
                    x_check = x_check + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise * timestep_fac[i]  # 2M SDE Noise
                    denoised_2 = model(x_check, sigma_fn(s_2s) * s_in, **extra_args)  # 2S
                    # 2S update from the original latent using the midpoint prediction.
                    x = (sigma_fn(t_next_2s) / sigma_fn(t_2s)) * x - (-h_2s).expm1() * denoised_2
                    x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up * (1. - timestep_fac[i])
                    x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise * timestep_fac[i]
                else:
                    x_prev = x
                    # DPM++ 2S a Noise
                    x = x_check + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
            else:
                # First step: no previous prediction, so use the first-order
                # candidate with ancestral noise.
                x_prev = x
                # DPM++ 2S a Noise
                x = x_check + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        old_denoised = denoised
        h_last = h
    return x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment