Last active
September 23, 2023 01:03
-
-
Save laksjdjf/3c6d5f4093c3c9fd0226bb71cb089049 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
1. Put this file in ComfyUI/custom_nodes.
2. Add the "Apply FreeU" node, found under the "loaders" category.
'''
import torch | |
from comfy.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, timestep_embedding, th | |
# https://github.com/ChenyangSi/FreeU
def Fourier_filter(x, threshold, scale):
    """Rescale the centred low-frequency band of ``x`` in the 2-D Fourier domain.

    The FFT is taken over the last two dimensions and shifted so the zero
    frequency sits at the centre. A centred square of half-width ``threshold``
    bins (clamped to the tensor bounds) is then multiplied by ``scale``, and the
    real part of the inverse FFT is returned.

    Args:
        x: Tensor whose last two dimensions are treated as (H, W); any leading
            dimensions are broadcast over. In this file it is a UNet
            skip-connection feature map.
        threshold: Half-width, in frequency bins, of the centred square that is
            rescaled. ``0`` selects nothing, so the data is returned unchanged
            (up to FFT round-off).
        scale: Multiplier applied inside the square. ``1.0`` short-circuits and
            returns ``x`` itself without any computation.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    if scale == 1.0:
        return x

    org_dtype = x.dtype
    # Work in float32; presumably because half-precision FFTs are not
    # supported on every backend. The result is cast back at the end.
    x = x.to(torch.float32)

    # FFT, then move the zero frequency to the centre of the last two dims.
    x_freq = torch.fft.fftn(x, dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

    H, W = x_freq.shape[-2:]
    crow, ccol = H // 2, W // 2
    # Allocate the mask on x's own device (a hard-coded .cuda() breaks CPU/MPS).
    # A (H, W) mask broadcasts over all leading dimensions.
    mask = torch.ones((H, W), device=x.device)
    # Clamp the lower bounds at 0: a negative slice start would wrap around to
    # the end of the axis and rescale the wrong bins when threshold > H // 2.
    mask[
        max(crow - threshold, 0):crow + threshold,
        max(ccol - threshold, 0):ccol + threshold,
    ] = scale
    x_freq = x_freq * mask

    # Undo the shift and return to the spatial domain.
    x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
    return x_filtered.to(org_dtype)
class FreeU:
    """ComfyUI node that applies FreeU (https://github.com/ChenyangSi/FreeU) to a MODEL.

    ``apply`` clones the incoming model and installs a UNet function wrapper.
    The wrapper re-implements the UNet forward pass with two changes in the two
    widest decoder stages:

    - the first half of the backbone channels is scaled by ``b1``/``b2``;
    - the matching skip connection is Fourier-filtered with ``s1``/``s2``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs (order and widget specs) for ComfyUI."""

        def factor():
            # Shared spec for b1/b2/s1/s2; 1.0 leaves the features unchanged.
            return ("FLOAT", {
                "default": 1.0,
                "min": -5.0,
                "max": 5.0,
                "step": 0.01,  # slider step
            })

        def integer(default, maximum):
            return ("INT", {
                "default": default,
                "min": 0,
                "max": maximum,
                "step": 1,
                "display": "number",
            })

        return {
            "required": {
                "model": ("MODEL", ),
                "b1": factor(),
                "b2": factor(),
                "s1": factor(),
                "s2": factor(),
                "threshold": integer(1, 100),
                "start": integer(0, 1000),
                "end": integer(1000, 1000),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, b1, b2, s1, s2, threshold, start, end):
        """Return a clone of ``model`` whose UNet forward pass applies FreeU.

        Args:
            model: ComfyUI MODEL object; it is cloned, not modified.
            b1: Backbone scale for the decoder stage at 4x ``model_channels``.
            b2: Backbone scale for the decoder stage at 2x ``model_channels``.
            s1: Skip-connection low-frequency scale paired with ``b1``.
            s2: Skip-connection low-frequency scale paired with ``b2``.
            threshold: Half-width passed to ``Fourier_filter``.
            start: FreeU is active only while the model timestep is
                ``<= 1000 - start``.
            end: FreeU is active only while the model timestep is
                ``>= 1000 - end``.

        Returns:
            A one-element tuple holding the patched model clone.
        """
        # The wrapper closes over these locals rather than attributes on
        # ``self``, so a later call on the same node instance cannot change the
        # behaviour of a model returned by an earlier call.
        t_upper = 1000 - start
        t_lower = 1000 - end
        new_model = model.clone()

        def apply_model(model_function, kwargs):
            """UNet wrapper: the stock forward pass with FreeU inlined.

            ``model_function`` is not called. The forward pass is copied from
            the ComfyUI revision linked below so the decoder can be modified,
            and must be kept in sync with it.
            """
            x = kwargs["input"]
            t = kwargs["timestep"]
            c = kwargs["c"]
            # Gate FreeU on the timestep of the first batch element.
            enable = t_lower <= t[0] <= t_upper

            #######################################################
            ###https://github.com/comfyanonymous/ComfyUI/blob/29ccf9f471e3b2ad4f4a08ba9f34698d357f8547/comfy/model_base.py#L51
            #######################################################
            c_concat = c.get("c_concat", None)
            c_crossattn = c.get("c_crossattn", None)
            c_adm = c.get("c_adm", None)
            control = c.get("control", None)
            transformer_options = c.get("transformer_options", None)
            if transformer_options is None:
                transformer_options = {}

            xc = torch.cat([x, c_concat], dim=1) if c_concat is not None else x
            dtype = new_model.model.get_dtype()
            xc = xc.to(dtype)
            t = t.to(dtype)
            context = c_crossattn.to(dtype) if c_crossattn is not None else None
            if c_adm is not None:
                c_adm = c_adm.to(dtype)

            #######################################################
            ###https://github.com/comfyanonymous/ComfyUI/blob/29ccf9f471e3b2ad4f4a08ba9f34698d357f8547/comfy/ldm/modules/diffusionmodules/openaimodel.py#L600
            #######################################################
            unet = new_model.model.diffusion_model
            y = c_adm
            # Backbone width of each FreeU stage -> (backbone scale, skip scale).
            # For model_channels == 320 these keys are 1280 and 640.
            freeu_scales = {
                unet.model_channels * 4: (b1, s1),
                unet.model_channels * 2: (b2, s2),
            }

            transformer_options["original_shape"] = list(xc.shape)
            transformer_options["current_index"] = 0

            assert (y is not None) == (
                unet.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"
            hs = []
            t_emb = timestep_embedding(t, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)

            if unet.num_classes is not None:
                assert y.shape[0] == xc.shape[0]
                emb = emb + unet.label_emb(y)

            # Encoder: keep every block output for the skip connections.
            h = xc.type(unet.dtype)
            for block_id, module in enumerate(unet.input_blocks):
                transformer_options["block"] = ("input", block_id)
                h = forward_timestep_embed(module, h, emb, context, transformer_options)
                hs.append(h)
                if control is not None and 'input' in control and len(control['input']) > 0:
                    ctrl = control['input'].pop()
                    if ctrl is not None:
                        h += ctrl

            transformer_options["block"] = ("middle", 0)
            h = forward_timestep_embed(unet.middle_block, h, emb, context, transformer_options)
            if control is not None and 'middle' in control and len(control['middle']) > 0:
                ctrl = control['middle'].pop()
                if ctrl is not None:
                    h += ctrl

            # Decoder: consume skip connections in reverse order.
            for block_id, module in enumerate(unet.output_blocks):
                transformer_options["block"] = ("output", block_id)
                hsp = hs.pop()

                # --------------- FreeU ---------------
                if enable and h.shape[1] in freeu_scales:
                    backbone_scale, skip_scale = freeu_scales[h.shape[1]]
                    # Scale the first half of the backbone channels in place.
                    h[:, :h.shape[1] // 2] *= backbone_scale
                    hsp = Fourier_filter(hsp, threshold=threshold, scale=skip_scale)
                # -------------------------------------

                if control is not None and 'output' in control and len(control['output']) > 0:
                    ctrl = control['output'].pop()
                    if ctrl is not None:
                        hsp += ctrl
                h = torch.cat([h, hsp], dim=1)
                del hsp
                output_shape = hs[-1].shape if hs else None
                h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)

            h = h.type(xc.dtype)
            if unet.predict_codebook_ids:
                return unet.id_predictor(h).float()
            return unet.out(h).float()

        new_model.set_model_unet_function_wrapper(apply_model)
        return (new_model, )
# Module exports for ComfyUI's custom-node loading (this file is meant to live
# in ComfyUI/custom_nodes). Keys are node type names; values are the node
# class and its human-readable title respectively.
NODE_CLASS_MAPPINGS = dict(FreeU=FreeU)
NODE_DISPLAY_NAME_MAPPINGS = dict(FreeU="Apply FreeU")

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment