Skip to content

Instantly share code, notes, and snippets.

@Clybius
Created July 3, 2023 17:55
Show Gist options
  • Save Clybius/9a4221245d27898f9c6e8c7cf4fc7eff to your computer and use it in GitHub Desktop.
Save Clybius/9a4221245d27898f9c6e8c7cf4fc7eff to your computer and use it in GitHub Desktop.
DPM++ 2M SDE Adaptive Sampler | A modified 2M SDE sampler complete with cosine similarity matching with a 2S step, and adaptive second order sampling.
# How to set up in A1111 Stable Diffusion and its various forks
1) You can install the sampler by adding it to the bottom of `repositories/k-diffusion/k_diffusion/sampling.py`
2) Then within `modules/sd_samplers_kdiffusion.py`, add the following to `samplers_k_diffusion`
('DPM++ 2M SDE Adaptive', 'sample_dpmpp_2m_sde_adaptive', ['c_dpmpp_2m_sde_ad'], {"brownian_noise": True, 'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
Notes: This has been tested on Vlad's A1111 fork. Step 2 may not be exactly the same on native A1111, but you should be able to resolve any errors by following how the existing entries in `samplers_k_diffusion` are written.
@torch.no_grad()
def sample_dpmpp_2m_sde_adaptive(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM++ 2M SDE steps with adaptive 2S ancestral sampling, merged by cosine similarity.

    Every iteration first builds a DPM++ 2M SDE candidate for the next ``x``.
    From the second iteration on, that candidate is blended with a DPM++ 2S
    ancestral update through a similarity mask ``k`` computed from ``x`` and
    the candidate, and the scaled difference between ``x`` and the blend is
    handed to a PID step-size controller. When the controller accepts, the
    model is evaluated a second time on the noised blend and ``x`` is advanced
    with that second prediction; otherwise the blend plus ancestral noise
    becomes the next ``x``. When the next sigma is 0 a plain
    ``x + d * dt`` denoising step is taken instead.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``;
            its return value is used as the denoised prediction.
        x: Starting tensor. NOTE(review): the mask code indexes
            ``k[0, :, :]``, so this presumably expects a 4-D
            (batch, channel, height, width) input -- confirm.
        sigmas: Tensor of noise levels, visited in index order
            (presumably 1-D and decreasing -- confirm against the caller).
        extra_args: Optional dict of extra keyword arguments for ``model``.
        callback: Optional callable, called once per step with a dict holding
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        eta: Passed to ``get_ancestral_step`` and used to scale the step size
            in the SDE terms (``eta * h``).
        s_noise: Multiplier applied to every injected noise term.
        noise_sampler: Optional callable ``noise_sampler(sigma, sigma_next)``
            returning noise shaped like ``x``. Defaults to a
            ``BrownianTreeNoiseSampler`` built from ``x`` and the range of
            positive sigmas.

    Returns:
        The tensor ``x`` after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    if noise_sampler is None:
        # The sigma range is only needed to construct the default sampler.
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    s_in = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    # NOTE(review): the positional arguments presumably follow k-diffusion's
    # PIDStepSizeController(h, pcoeff, icoeff, dcoeff, order, accept_safety)
    # signature -- confirm against the class definition.
    pid = PIDStepSizeController(0.05, 0., 1., 0., 1.5, 0.71)
    cos_sim = torch.nn.CosineSimilarity()

    # Half-cosine ramp going from 1 at index 0 to 0 at the last index. On
    # accepted steps it splits the injected noise between the ancestral term
    # (weight 1 - fac) and the 2M SDE term (weight fac). It does not depend on
    # the loop state, so it is built once instead of on every iteration.
    timestep = torch.linspace(0, math.pi, torch.numel(sigmas), device=x.device)
    timestep_fac = torch.cos(timestep) / 2 + 0.5

    old_denoised = None
    h_last = None
    x_prev = x

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Work in negative-log-sigma "time"; h is the step size in that space.
        t, s = -sigma.log(), -sigma_next.log()
        h = s - t

        if sigma_next == 0:
            # Denoising step
            d = to_d(x, sigma, denoised)
            dt = sigma_next - sigma
            x = x + d * dt
        else:
            eta_h = eta * h
            # First-order DPM++ 2M SDE candidate.
            x_check = sigma_next / sigma * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            # On the first step there is no previous prediction, so nothing
            # is blended and the controller is not consulted.
            accept = False
            if old_denoised is not None:
                r = h_last / h

                # Similarity mask between x and the candidate. Both tensors
                # are L2-normalised along dim 0, compared with cosine
                # similarity (module default dim), and averaged with a global
                # normalised dot product.
                x_norm = torch.nn.functional.normalize(x, p=2, dim=0)
                x_check_norm = torch.nn.functional.normalize(x_check, p=2, dim=0)
                simab = cos_sim(x_norm, x_check_norm)
                dot_product = torch.dot(x_norm.view(-1), x_check_norm.view(-1))
                magnitude_similarity = dot_product / (torch.norm(x_norm) * torch.norm(x_check_norm))
                combined_similarity = (simab + magnitude_similarity) / 2.0
                # Rescale by the candidate's value range, then clamp to [0, 1].
                k = (combined_similarity - x_check.min()) / (x_check.max() - x_check.min())
                # NOTE(review): only the slice at index 0 of the first axis is
                # kept and then broadcast; for batch sizes above 1 this
                # presumably reuses the first sample's mask -- confirm intent.
                k = k[0, :, :].clip(min=0, max=1.0)

                # (1 - k) share: 2M SDE candidate plus its second-order
                # correction from the previous prediction.
                x_check_pre = (1 - k) * x_check + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) * (1 - k)  # 2M SDE

                # DPM++ 2S vars: step towards sigma_down, with s_2s the
                # midpoint in negative-log-sigma time.
                t_2s, t_next_2s = t_fn(sigma), t_fn(sigma_down)
                h_2s = t_next_2s - t_2s
                s_2s = t_2s + 0.5 * h_2s
                # k share: first-order 2S ancestral update.
                x_check_pre2 = k * (sigma_fn(t_next_2s) / sigma_fn(t_2s)) * x - (t_2s - t_next_2s).expm1() * denoised * k
                x_check = x_check_pre + x_check_pre2

                # Per-element tolerance: 5% of the larger magnitude of x and
                # x_prev, floored at 0.0078. The RMS of the scaled difference
                # is what the PID controller judges.
                delta = torch.maximum(torch.tensor(0.0078), 0.05 * torch.maximum(x.abs(), x_prev.abs()))
                error = torch.linalg.norm((x - x_check) / delta) / x.numel() ** 0.5
                accept = pid.propose_step(error)

            # Remember the pre-update state for the next step's tolerance.
            x_prev = x
            if accept:
                # Noise the blended candidate with both noise terms, weighted
                # by the cosine ramp, and evaluate the model on it at the
                # midpoint sigma.
                x_check = x_check + noise_sampler(sigma, sigma_next) * s_noise * sigma_up * (1. - timestep_fac[i])
                x_check = x_check + noise_sampler(sigma, sigma_next) * sigma_next * (-2 * eta_h).expm1().neg().sqrt() * s_noise * timestep_fac[i]  # 2M SDE Noise
                denoised_2 = model(x_check, sigma_fn(s_2s) * s_in, **extra_args)  # 2S
                # Advance x itself using the second prediction, then add the
                # same two weighted noise terms.
                x = (sigma_fn(t_next_2s) / sigma_fn(t_2s)) * x - (-h_2s).expm1() * denoised_2
                x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up * (1. - timestep_fac[i])
                x = x + noise_sampler(sigma, sigma_next) * sigma_next * (-2 * eta_h).expm1().neg().sqrt() * s_noise * timestep_fac[i]
            else:
                # DPM++ 2S a Noise: first step or rejected proposal -- take
                # the candidate and add plain ancestral noise.
                x = x_check + noise_sampler(sigma, sigma_next) * s_noise * sigma_up

        old_denoised = denoised
        h_last = h
    return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment