@Clybius
Created April 26, 2023 16:07
DPM++ 2M LSA (Low Sigma Ancestral) Sampler (K-Diffusion)
@torch.no_grad()
def sample_dpmpp_2m_lsa(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(2M) with ancestral noise at sigmas below 1.1."""
    # Relies on helpers already defined in k_diffusion/sampling.py:
    # trange, default_noise_sampler, get_ancestral_step.
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None:
            # First step: plain first-order DPM-Solver++ update.
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        elif sigmas[i + 1] == 0 or sigma_down == 0:
            # Final step (or the ancestral step collapses to zero): first-order update toward sigma_down.
            x = (sigma_fn(t_fn(sigma_down)) / sigma_fn(t)) * x - (-h).expm1() * denoised
        elif sigmas[i + 1] > 0 and sigmas[i + 1] < 1.1:
            # Low-sigma region: second-order update toward sigma_down, then add ancestral noise.
            h = t_fn(sigma_down) - t
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_fn(sigma_down)) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        else:
            # High-sigma region: standard DPM-Solver++(2M) second-order update, no ancestral noise.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x
Clybius commented Apr 26, 2023

"This code is FUCKED!!!"

I know, I'll work on it over time (hopefully).

You can place this at the bottom of k-diffusion/k_diffusion/sampling.py.
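For reference, a minimal sketch of how the function becomes available once it has been appended there (assuming a local k-diffusion checkout on your Python path):

# Assumes sample_dpmpp_2m_lsa was appended to k_diffusion/sampling.py in a local checkout.
from k_diffusion.sampling import sample_dpmpp_2m_lsa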

You must use a sampling scheduler other than the default, such as the Karras scheduler, or else you will get latent burning!

Alternatively, you can lower the eta value. 1.0 is rather aggressive and is only intended for high-step sampling with schedulers such as the Karras or Polyexponential (personal favorite) scheduler.
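As a rough illustration, here is a minimal usage sketch under those assumptions. The denoiser wrapper, latent shape, sigma range, step count, and eta value below are placeholders, not recommendations from this gist:

import torch
from k_diffusion import sampling

# Placeholder: any k-diffusion denoiser wrapper, e.g. k_diffusion.external.CompVisDenoiser(inner_model).
denoiser = ...

# Build a Karras (or Polyexponential) sigma schedule instead of the default one.
sigmas = sampling.get_sigmas_karras(n=30, sigma_min=0.0292, sigma_max=14.6146, device='cuda')
# sigmas = sampling.get_sigmas_polyexponential(n=30, sigma_min=0.0292, sigma_max=14.6146, device='cuda')

# Start from noise scaled to the first sigma (latent shape is a placeholder).
x = torch.randn(1, 4, 64, 64, device='cuda') * sigmas[0]

# eta=1.0 is aggressive; lowering it reduces the ancestral noise added below sigma 1.1.
samples = sampling.sample_dpmpp_2m_lsa(denoiser, x, sigmas, eta=0.75)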
