Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Last active September 23, 2023 01:03
Show Gist options
  • Save laksjdjf/3c6d5f4093c3c9fd0226bb71cb089049 to your computer and use it in GitHub Desktop.
Save laksjdjf/3c6d5f4093c3c9fd0226bb71cb089049 to your computer and use it in GitHub Desktop.
'''
1. Put this file in ComfyUI/custom_nodes.
2. Add the "Apply FreeU" node, listed under the "loaders" category.
'''
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, timestep_embedding, th
# https://github.com/ChenyangSi/FreeU
def Fourier_filter(x, threshold, scale):
    """Scale the low-frequency band of ``x`` in the 2-D Fourier domain.

    The last two dimensions of ``x`` are transformed with an FFT, shifted so
    the zero-frequency term sits at the centre, and every coefficient inside
    a centred square of half-width ``threshold`` is multiplied by ``scale``.
    The result is transformed back and only its real part is kept.

    Args:
        x: Tensor whose last two dimensions are filtered.
        threshold: Half-width (in frequency bins) of the centred square that
            gets rescaled. ``0`` selects nothing; values larger than half the
            spatial size cover the whole spectrum.
        scale: Multiplier applied inside the square. ``1.0`` is a no-op and
            returns ``x`` itself without any computation.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    if scale == 1.0:
        return x

    org_dtype = x.dtype
    # The FFT is run in float32 regardless of the incoming dtype.
    x = x.to(torch.float32)

    # FFT
    x_freq = torch.fft.fftn(x, dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

    H, W = x_freq.shape[-2:]
    # Build the mask on the input's own device (not a hard-coded CUDA device)
    # and let it broadcast over the leading dimensions.
    mask = torch.ones((H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    # Clamp the lower bounds: a negative start index would wrap around and
    # select the wrong (or an empty) region when threshold > crow / ccol.
    top = max(crow - threshold, 0)
    left = max(ccol - threshold, 0)
    mask[top:crow + threshold, left:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(org_dtype)
class FreeU:
    """Node that returns a copy of a model whose UNet forward applies FreeU.

    ``apply`` clones the incoming model and installs a UNet function wrapper
    on the clone. The wrapper re-implements the model/UNet forward pass (see
    the upstream links inside ``apply_model``) and, in the decoder loop,
    rescales part of the backbone features (``b1``/``b2``) and Fourier-filters
    the skip features (``s1``/``s2``).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        ``b1``, ``b2``, ``s1`` and ``s2`` are floats in [-5, 5] defaulting to
        1.0; ``threshold`` is an int in [0, 100]; ``start`` and ``end`` are
        ints in [0, 1000].
        """
        def scale_input():
            # A fresh tuple/dict per key so the entries are never shared.
            return ("FLOAT", {
                "default": 1.0,
                "min": -5.0,  # Minimum value
                "max": 5.0,  # Maximum value
                "step": 0.01,  # Slider's step
            })

        def int_input(default, maximum):
            return ("INT", {
                "default": default,
                "min": 0,
                "max": maximum,
                "step": 1,
                "display": "number",
            })

        return {
            "required": {
                "model": ("MODEL", ),
                "b1": scale_input(),
                "b2": scale_input(),
                "s1": scale_input(),
                "s2": scale_input(),
                "threshold": int_input(1, 100),
                "start": int_input(0, 1000),
                "end": int_input(1000, 1000),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, b1, b2, s1, s2, threshold, start, end):
        """Clone ``model`` and install the FreeU forward wrapper on the clone.

        Args:
            model: Object providing ``clone()``; the clone must provide
                ``set_model_unet_function_wrapper`` and ``.model``.
            b1: Multiplier for the first 640 channels of 1280-channel
                decoder features.
            b2: Multiplier for the first 320 channels of 640-channel
                decoder features.
            s1: Fourier-filter scale for the skip features joined with
                1280-channel decoder features.
            s2: Fourier-filter scale for the skip features joined with
                640-channel decoder features.
            threshold: Half-width passed to ``Fourier_filter``.
            start: FreeU is active while ``timestep <= 1000 - start``.
            end: FreeU is active while ``timestep >= 1000 - end``.

        Returns:
            A one-element tuple containing the patched model clone.
        """
        # Kept as instance attributes for backward compatibility; the wrapper
        # below deliberately does NOT read them (see start_t / end_t).
        self.b1 = b1
        self.b2 = b2
        self.s1 = s1
        self.s2 = s2
        self.threshold = threshold

        new_model = model.clone()

        # The wrapper captures these per-call locals instead of ``self.*`` so
        # that calling ``apply`` again on the same node instance cannot change
        # the behaviour of a model returned by an earlier call.
        start_t = 1000 - start
        end_t = 1000 - end
        self.start = start_t
        self.end = end_t

        def apply_model(model_function, kwargs):
            # ``model_function`` is intentionally unused: the forward pass is
            # re-implemented below so the decoder loop can be modified.
            # Only the first element of the timestep batch decides whether
            # FreeU is active for this call.
            enable = end_t <= kwargs["timestep"][0] <= start_t

            #######################################################
            ###https://github.com/comfyanonymous/ComfyUI/blob/29ccf9f471e3b2ad4f4a08ba9f34698d357f8547/comfy/model_base.py#L51
            #######################################################
            x = kwargs["input"]
            t = kwargs["timestep"]
            cond = kwargs["c"]
            c_concat = cond.get("c_concat", None)
            c_crossattn = cond.get("c_crossattn", None)
            c_adm = cond.get("c_adm", None)
            control = cond.get("control", None)
            transformer_options = cond.get("transformer_options", None)
            if transformer_options is None:
                # Keys are assigned on this dict below; None would raise.
                transformer_options = {}

            if c_concat is not None:
                xc = torch.cat([x] + [c_concat], dim=1)
            else:
                xc = x

            context = c_crossattn
            dtype = new_model.model.get_dtype()
            xc = xc.to(dtype)
            t = t.to(dtype)
            if context is not None:
                context = context.to(dtype)
            if c_adm is not None:
                c_adm = c_adm.to(dtype)

            #######################################################
            ###https://github.com/comfyanonymous/ComfyUI/blob/29ccf9f471e3b2ad4f4a08ba9f34698d357f8547/comfy/ldm/modules/diffusionmodules/openaimodel.py#L600
            #######################################################
            unet = new_model.model.diffusion_model
            x = xc
            timesteps = t
            y = c_adm

            transformer_options["original_shape"] = list(x.shape)
            transformer_options["current_index"] = 0

            assert (y is not None) == (
                unet.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"

            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)

            if unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)

            h = x.type(unet.dtype)
            for block_index, module in enumerate(unet.input_blocks):
                transformer_options["block"] = ("input", block_index)
                h = forward_timestep_embed(module, h, emb, context, transformer_options)
                hs.append(h)
                if control is not None and 'input' in control and len(control['input']) > 0:
                    ctrl = control['input'].pop()
                    if ctrl is not None:
                        # In-place on purpose: the tensor just stored in
                        # ``hs`` is the same object and receives the residual.
                        h += ctrl

            transformer_options["block"] = ("middle", 0)
            h = forward_timestep_embed(unet.middle_block, h, emb, context, transformer_options)
            if control is not None and 'middle' in control and len(control['middle']) > 0:
                ctrl = control['middle'].pop()
                if ctrl is not None:
                    h += ctrl

            for block_index, module in enumerate(unet.output_blocks):
                transformer_options["block"] = ("output", block_index)
                hsp = hs.pop()

                # --------------- FreeU code -----------------------
                if enable:
                    # Only operate on the first two stages
                    if h.shape[1] == 1280:
                        h[:, :640] = h[:, :640] * b1
                        hsp = Fourier_filter(hsp, threshold=threshold, scale=s1)
                    elif h.shape[1] == 640:
                        h[:, :320] = h[:, :320] * b2
                        hsp = Fourier_filter(hsp, threshold=threshold, scale=s2)
                # ---------------------------------------------------------

                if control is not None and 'output' in control and len(control['output']) > 0:
                    ctrl = control['output'].pop()
                    if ctrl is not None:
                        hsp += ctrl

                h = th.cat([h, hsp], dim=1)
                del hsp
                if len(hs) > 0:
                    output_shape = hs[-1].shape
                else:
                    output_shape = None
                h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)

            h = h.type(x.dtype)
            if unet.predict_codebook_ids:
                return unet.id_predictor(h).float()
            else:
                return unet.out(h).float()

        new_model.set_model_unet_function_wrapper(apply_model)
        return (new_model, )
# Registration tables for this module: node identifier -> implementing class,
# and node identifier -> human-readable name.
# NOTE(review): presumably consumed by ComfyUI's custom-node loader — confirm.
NODE_CLASS_MAPPINGS = dict(FreeU=FreeU)
NODE_DISPLAY_NAME_MAPPINGS = dict(FreeU="Apply FreeU")

# Only the two registration tables are exported by ``from module import *``.
__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment