Created October 12, 2022 15:39
-
-
Save torridgristle/ca942c2e1c31ac31111d31931ed1dfbb to your computer and use it in GitHub Desktop.
Stable Diffusion CFGDenoiser with slew-rate limiting between sampler steps, plus optional frequency splitting to preserve detail.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torchvision.transforms.functional as TF | |
class CFGDenoiserSlew(nn.Module):
    '''
    Classifier-free-guidance denoiser that clamps how much the guided output
    may change between consecutive sampler steps (slew limiting).

    "limit" is the clamp bound applied to the per-element change from the
    previous step's output. The default is 0.2; empirically 0.4-0.8 seem good,
    while 1.6 and 3.2 show very little difference and might represent the
    upper bound of useful values.
    "blur" is the Gaussian kernel size passed to torchvision's gaussian_blur.
    It is used to split frequencies: low frequencies come from the limited
    output, high frequencies from the unlimited output, in an attempt to
    preserve detail and color. Values <= 1 disable the split. Even or float
    values are converted to the next odd int, because gaussian_blur requires
    an odd integer kernel size.
    "last_step_is_blur", if true, stores the frequency-split output (rather
    than the plain limited output) as the reference for the next step's
    clamp; this can look nicer.

    State is kept between calls. It is reset when the largest sigma in the
    batch rises above the previous call's value (taken to mean a new sampling
    run has started), or when the output shape changes.
    '''

    def __init__(self, model, limit=0.2, blur=5, last_step_is_blur=True):
        super().__init__()
        self.inner_model = model
        self.last_sigma = 0.0  # Largest sigma of the previous call; a rise signals a new sampling run
        self.last_step = None  # Previous output, used to measure change between steps
        self.limit = limit  # The clamp bounds
        self.blur = blur  # Gaussian kernel size for freq splitting and merging limited and non-limited outputs
        self.last_step_is_blur = last_step_is_blur  # Track the freq-split output instead of the plain limited output

    def forward(self, x, sigma, uncond, cond, cond_scale):
        '''
        Evaluate the inner model on unconditional and conditional inputs in one
        doubled batch, combine them with classifier-free guidance, then
        slew-limit the result against the previous step's output.

        x: model input batch; sigma: noise-level tensor concatenated alongside x
        (assumed one entry per sample; TODO confirm against callers).
        uncond / cond: conditioning, passed to the inner model as ``cond=``.
        cond_scale: guidance scale.
        Returns one chunk of the inner model's output after guidance and
        limiting.
        '''
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        result_clean = uncond_out + (cond_out - uncond_out) * cond_scale

        # Reduce sigma to a Python scalar: comparing a multi-element tensor in
        # an `if` raises, so batch sizes > 1 would otherwise crash.
        sigma_now = float(sigma.max())
        if sigma_now > self.last_sigma:
            self.last_step = None  # Sigma went up: a new sampling run started
        self.last_sigma = sigma_now
        if self.last_step is not None and self.last_step.shape != result_clean.shape:
            self.last_step = None  # Stale state from a differently-shaped run

        if self.last_step is None:
            # First step of a run: nothing to limit against yet.
            self.last_step = result_clean
            return result_clean

        diff = result_clean - self.last_step
        limited = diff.clamp(-self.limit, self.limit) + self.last_step
        result = limited
        if self.blur > 1:
            kernel = int(self.blur)
            if kernel % 2 == 0:
                kernel += 1  # gaussian_blur only accepts odd kernel sizes
            low = TF.gaussian_blur(limited, kernel)
            high = result_clean - TF.gaussian_blur(result_clean, kernel)
            result = low + high  # Limited low frequencies + unlimited detail

        # Reference for the next step: post-split output or plain limited output.
        self.last_step = result if self.last_step_is_blur else limited
        return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment