Created
April 26, 2023 16:07
-
-
Save Clybius/ac0782467beb4b3df1a7caff13737b08 to your computer and use it in GitHub Desktop.
DPM++ 2M LSA (Low Sigma Ancestral) Sampler (K-Diffusion)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@torch.no_grad()
def sample_dpmpp_2m_lsa(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, lsa_sigma_max=1.1):
    """DPM-Solver++(2M) with ancestral noise at sigmas below ``lsa_sigma_max``.

    The step type depends on where each step lands:

    * The first step is always a deterministic first-order step, because no
      previous prediction exists yet.
    * A step whose target sigma is 0 returns the denoised prediction.
    * A step whose target sigma is positive and below ``lsa_sigma_max``
      ("low sigma") is ancestral. The solver steps to ``sigma_down`` and then
      adds ``sigma_up`` worth of fresh noise, using the split returned by
      ``get_ancestral_step``.
    * Every other step is a plain deterministic DPM-Solver++(2M) step.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
            Its output is used as the denoised estimate.
        x: Initial latent. Presumably already scaled to ``sigmas[0]``
            (confirm with the caller).
        sigmas: Sigma schedule. Step ``i`` goes from ``sigmas[i]`` to
            ``sigmas[i + 1]``.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional callable, invoked once per step with a dict holding
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` (progress-bar toggle).
        eta: Ancestral noise strength passed to ``get_ancestral_step``.
        s_noise: Multiplier applied to the injected noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``. Defaults to
            ``default_noise_sampler(x)``.
        lsa_sigma_max: Steps whose target sigma is above 0 and strictly below
            this value get ancestral noise. Defaults to 1.1, the previously
            hard-coded threshold.

    Returns:
        The latent after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    def dpmpp_2m_step(x, denoised, old_denoised, sigma, sigma_prev, sigma_target):
        # Deterministic second-order (2M) update from `sigma` to `sigma_target`,
        # using the previous step's prediction taken at `sigma_prev`.
        t, t_target = t_fn(sigma), t_fn(sigma_target)
        h = t_target - t
        h_last = t - t_fn(sigma_prev)
        r = h_last / h
        denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
        return (sigma_fn(t_target) / sigma_fn(t)) * x - (-h).expm1() * denoised_d

    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
        if old_denoised is None:
            # No history yet: deterministic first-order step to sigma_next.
            t, t_next = t_fn(sigma), t_fn(sigma_next)
            h = t_next - t
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        elif sigma_next == 0:
            # The first-order update to sigma 0 reduces exactly to the prediction.
            x = denoised
        elif sigma_next < lsa_sigma_max:
            sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
            if sigma_down == 0:
                # eta is large enough that the whole step is re-noised. Stepping to
                # sigma 0 yields the prediction itself; the 2M correction would
                # divide by zero here (h is infinite, so r == 0).
                x = denoised
            else:
                x = dpmpp_2m_step(x, denoised, old_denoised, sigma, sigmas[i - 1], sigma_down)
            # Re-inject noise so the latent ends the step at sigma_next.
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
        else:
            # High sigma: plain deterministic DPM-Solver++(2M) step.
            x = dpmpp_2m_step(x, denoised, old_denoised, sigma, sigmas[i - 1], sigma_next)
        old_denoised = denoised
    return x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
"This code is FUCKED!!!"
i know, i'll work on it over time (hopefully)
You can place this function at the bottom of `k-diffusion/k_diffusion/sampling.py`. You must use a different sampling scheduler than the usual one, such as the Karras scheduler, or else you will get latent burning!
Alternatively, you can lower the `eta` value. The default of 1.0 is rather aggressive and is only intended for high-step sampling with schedulers such as the Karras or Polyexponential (personal favorite) scheduler.