Dynamic thresholding of stable-diffusion latents, by referring to the known-good dynamic range of CFG 7.5
import torch
from torch import nn, Tensor, FloatTensor
from typing import Protocol, Optional
from k_diffusion.external import CompVisDenoiser
from k_diffusion.sampling import sample_heun

class DiffusionModel(Protocol):
    def __call__(self, x: Tensor, sigma: Tensor, **kwargs) -> Tensor: ...

class DiffusionModelMixin(DiffusionModel):
    inner_model: DiffusionModel

# workaround until k-diffusion introduces an official base model wrapper,
# to make the wrapper forward all method calls to the wrapped model
# https://github.com/crowsonkb/k-diffusion/pull/23#issuecomment-1239937951
class CompVisDenoiserWrapper(CompVisDenoiser, DiffusionModelMixin):
    inner_model: DiffusionModel

    def __init__(self, model: DiffusionModel, quantize=False):
        CompVisDenoiser.__init__(self, model, quantize=quantize)
        DiffusionModelMixin.__init__(self)

class BaseModelWrapper(nn.Module, DiffusionModelMixin):
    inner_model: DiffusionModel

    def __init__(self, inner_model: DiffusionModel):
        super().__init__()
        self.inner_model = inner_model
        DiffusionModelMixin.__init__(self)

def repeat_along_dim_0(t: Tensor, factor: int) -> Tensor:
    """
    Repeats a tensor's contents along its 0th dim `factor` times.

    repeat_along_dim_0(torch.tensor([[0,1]]), 2)
    tensor([[0, 1],
            [0, 1]])
    # shape changes from (1, 2)
    #                 to (2, 2)

    repeat_along_dim_0(torch.tensor([[0,1],[2,3]]), 2)
    tensor([[0, 1],
            [2, 3],
            [0, 1],
            [2, 3]])
    # shape changes from (2, 2)
    #                 to (4, 2)
    """
    assert factor >= 1
    if factor == 1:
        return t
    if t.size(dim=0) == 1:
        # prefer expand() whenever we can, since it doesn't copy
        return t.expand(factor * t.size(dim=0), *(-1,)*(t.ndim-1))
    return t.repeat((factor, *(1,)*(t.ndim-1)))

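# CFG denoiser with dynamic thresholding: it guides at the requested cond_scale,
# but clamps/rescales the result so its per-sample, per-channel dynamic range
# matches that of a lower, known-good "mimic" scale (e.g. CFG 7.5).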
class CFGDynTheshDenoiser(BaseModelWrapper):
    dynamic_thresholding_percentile: float
    dynamic_thresholding_mimic_scale: float

    def __init__(
        self,
        model: DiffusionModel,
        dynamic_thresholding_percentile: float,
        dynamic_thresholding_mimic_scale: float,
    ):
        super().__init__(model)
        self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
        self.dynamic_thresholding_mimic_scale = dynamic_thresholding_mimic_scale

    def forward(
        self,
        x: FloatTensor,
        sigma: FloatTensor,
        cond: FloatTensor,
        cond_scale: float = 1.0,
        uncond: Optional[FloatTensor] = None,
        **kwargs
    ) -> FloatTensor:
        if uncond is None or cond_scale == 1.0:
            return self.inner_model(x, sigma, cond=cond)
        cond_in = torch.cat([uncond, cond])
        del uncond, cond
        x_in = repeat_along_dim_0(x, cond_in.size(dim=0))
        del x
        sigma_in = repeat_along_dim_0(sigma, cond_in.size(dim=0))
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(cond_in.size(dim=0))
        del x_in, sigma_in, cond_in
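        # guidance direction, plus a reference prediction at the mimic scale
        # (e.g. CFG 7.5) whose per-channel dynamic range we treat as known-good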
        diff: Tensor = cond - uncond
        dynthresh_target: Tensor = uncond + diff * self.dynamic_thresholding_mimic_scale
        dt_flattened: Tensor = dynthresh_target.flatten(2)
        dt_means: Tensor = dt_flattened.mean(dim=2).unsqueeze(2)
        dt_recentered: Tensor = dt_flattened - dt_means
        dt_abs = dt_recentered.abs()
        dt_max = dt_abs.max(dim=2).values.unsqueeze(2)
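        # the guided prediction at the requested (typically much higher) cond_scale,
        # flattened and recentered the same way so its spread can be compared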
        ut: Tensor = uncond + diff * cond_scale
        ut_flattened: Tensor = ut.flatten(2)
        ut_means: Tensor = ut_flattened.mean(dim=2).unsqueeze(2)
        ut_centered: Tensor = ut_flattened - ut_means
        a = ut_centered.abs()
        ut_q = torch.quantile(a, self.dynamic_thresholding_percentile, dim=2).unsqueeze(2)
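        # clamp to the larger of the guided latents' chosen percentile and the reference's
        # max, then rescale so outliers are compressed back into the dynamic range the
        # mimic scale would have produced; finally restore the mean and original shape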
        s = torch.maximum(ut_q, dt_max)
        t_clamped = ut_centered.clamp(-s, s)
        t_normalized = t_clamped / s
        t_renormalized = t_normalized * dt_max
        uncentered: Tensor = t_renormalized + ut_means
        unflattened: Tensor = uncentered.unflatten(2, dynthresh_target.shape[2:])
        return unflattened

model = ...  # put your LatentDiffusionModel here
model_k_wrapped = CompVisDenoiserWrapper(model, quantize=True)
model_k_guidance = CFGDynTheshDenoiser(
    model_k_wrapped,
    # clamp away latent values that exceed the 99.9th percentile
    dynamic_thresholding_percentile=0.999,
    # use CFG 7.5 as our reference for a "known-good" dynamic range
    dynamic_thresholding_mimic_scale=7.5,
)

# now sample from model_k_guidance the same way you'd sample from any k-diffusion wrapped model
sample_heun(model_k_guidance, x, ...)
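
# A fuller sketch of that call (not from the original gist): it assumes a CompVis-style
# LatentDiffusionModel exposing get_learned_conditioning() and decode_first_stage(),
# a 512x512 model (1x4x64x64 latents), and an illustrative prompt, step count and scale.
cond = model.get_learned_conditioning(['a watercolour painting of a lighthouse'])
uncond = model.get_learned_conditioning([''])
sigmas = model_k_wrapped.get_sigmas(15)
# start from pure noise scaled to the highest sigma
x = torch.randn([1, 4, 64, 64], device=cond.device) * sigmas[0]
denoised_latents = sample_heun(
    model_k_guidance,
    x,
    sigmas,
    # guide hard (CFG 15) while mimicking CFG 7.5's dynamic range
    extra_args={'cond': cond, 'uncond': uncond, 'cond_scale': 15.0},
)
decoded_images = model.decode_first_stage(denoised_latents)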