'''
https://github.com/Zehong-Ma/MagCache
'''
from comfy.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, timestep_embedding, th, apply_control
import comfy.patcher_extension
import json
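
# MagCache (see the repository above) speeds up sampling by skipping the deep
# UNet blocks on steps where the output is predicted to change little. A
# calibration pass records, per timestep, the magnitude ratio between the
# current deep feature map and the previous step's; at inference those ratios
# are interpolated to the actual step count, and a step is skipped (reusing a
# cached residual) while the accumulated error stays small.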
def linear_interpolate(data: dict, num: float, scale: float) -> float:
    if not data:
        raise ValueError("data is empty")
    sorted_keys = sorted(data.keys())
    # If num is out of range, clamp to the nearest endpoint (no extrapolation)
    if num <= sorted_keys[0]:
        retval = data[sorted_keys[0]]
    if num >= sorted_keys[-1]:
        retval = data[sorted_keys[-1]]
    # Find the interval containing num
    for i in range(len(sorted_keys) - 1):
        x1, x2 = sorted_keys[i], sorted_keys[i + 1]
        if x1 <= num <= x2:
            y1, y2 = data[x1], data[x2]
            # Linear interpolation
            retval = y1 + (y2 - y1) * (num - x1) / (x2 - x1)
    return retval ** scale
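
# Example with a hypothetical two-point calibration table, interpolating at
# timestep 500 with the exponent left neutral (scale=1.0):
#   >>> linear_interpolate({0: 1.0, 1000: 0.8}, 500.0, 1.0)
#   0.9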
class MagCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "cache_depth": ("INT", {
                    "default": 3,
                    "min": 0,
                    "max": 12,
                    "step": 1,
                    "display": "number"
                }),
                "calibration": ("BOOLEAN", {"default": False}),
                "k": ("INT", {
                    "default": 6,
                    "min": 0,
                    "max": 100,
                    "step": 1,
                    "display": "number"
                }),
                "threshold": ("FLOAT", {
                    "default": 0.01,
                    "min": -0.01,
                    "max": 10,
                    "step": 0.001,
                    "round": 0.001,
                }),
                "retention": ("FLOAT", {
                    "default": 0.2,
                    "min": 0,
                    "max": 1.0,
                    "step": 0.01,
                    "round": 0.01,
                }),
                "ratios": ("STRING", {"multiline": True, "default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"
    def apply(self, model, cache_depth, calibration, k, threshold, retention, ratios):
        new_model = model.clone()
        residual = None
        accumulated_ratio = 1.0
        accumulated_error = 0.0
        accumulated_steps = 0
        skip_forwards = []
        assert calibration or ratios != "", "ratios must be specified when calibration is disabled"
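        # `ratios` is expected to be the JSON dict printed by a calibration
        # run, mapping integer timesteps to magnitude ratios, e.g.
        # (hypothetical values): {"999": 1.02, "950": 0.99, ...}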
if ratios != "" and not calibration:
ratios = json.loads(ratios)
ratios = {int(k): v for k, v in ratios.items()}
else:
ratios = {}
cos_sims = {}
        def hook_forward(self):
            def _forward(x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
                """
                Apply the model to an input batch.
                :param x: an [N x C x ...] Tensor of inputs.
                :param timesteps: a 1-D batch of timesteps.
                :param context: conditioning plugged in via crossattn
                :param y: an [N] Tensor of labels, if class-conditional.
                :return: an [N x C x ...] Tensor of outputs.
                """
                ######## Calc current step and skip forward
                nonlocal new_model, cache_depth, calibration, k, threshold, retention, ratios, cos_sims, residual, accumulated_ratio, accumulated_error, accumulated_steps, skip_forwards
                sample_sigmas = transformer_options["sample_sigmas"]
                sigmas = transformer_options["sigmas"]
                step = (sample_sigmas - sigmas[0].item()).abs().argmin().item()
                timestep_total = new_model.model.model_sampling.timestep(sample_sigmas)
                timestep = timestep_total[step].item()
                total_steps = sample_sigmas.shape[0] - 1
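                # At the first step, precompute a skip schedule for the whole
                # run: multiply the interpolated per-step ratios, accumulate
                # |accumulated_ratio - 1| as an error estimate, and mark a step
                # as skippable while the error stays under `threshold` and no
                # more than `k` steps accumulate in a row. The first
                # `retention` fraction of steps (at least one) is never
                # skipped.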
                if step == 0:
                    residual = None
                    accumulated_ratio = 1.0
                    accumulated_error = 0.0
                    accumulated_steps = 0
                    skip_forwards = []
                    if not calibration:
                        for i in range(total_steps):
                            if i < max(1, retention * (sample_sigmas.shape[0] - 1)):
                                skip_forwards.append(False)
                            else:
                                ratio = linear_interpolate(ratios, timestep_total[i].item(), (len(ratios) - 1) / (total_steps - 1))
                                accumulated_ratio *= ratio
                                error = abs(accumulated_ratio - 1.0)
                                accumulated_error += error
                                accumulated_steps += 1
                                if accumulated_error <= threshold and accumulated_steps <= k:
                                    skip_forwards.append(True)
                                else:
                                    skip_forwards.append(False)
                                    accumulated_ratio = 1.0
                                    accumulated_error = 0.0
                                    accumulated_steps = 0
                        skip_indices = [i for i, skip in enumerate(skip_forwards) if skip]
                        print(f"Skip {len(skip_indices)} steps ({skip_indices})")

                skip_forward = False if calibration else skip_forwards[step]
                ########

                transformer_options["original_shape"] = list(x.shape)
                transformer_options["transformer_index"] = 0
                transformer_patches = transformer_options.get("patches", {})

                num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
                image_only_indicator = kwargs.get("image_only_indicator", None)
                time_context = kwargs.get("time_context", None)

                assert (y is not None) == (
                    self.num_classes is not None
                ), "must specify y if and only if the model is class-conditional"
                hs = []
                t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
                emb = self.time_embed(t_emb)

                if "emb_patch" in transformer_patches:
                    patch = transformer_patches["emb_patch"]
                    for p in patch:
                        emb = p(emb, self.model_channels, transformer_options)

                if self.num_classes is not None:
                    assert y.shape[0] == x.shape[0]
                    emb = emb + self.label_emb(y)

                h = x
                for id, module in enumerate(self.input_blocks):
                    transformer_options["block"] = ("input", id)
                    h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
                    h = apply_control(h, control, 'input')
                    if "input_block_patch" in transformer_patches:
                        patch = transformer_patches["input_block_patch"]
                        for p in patch:
                            h = p(h, transformer_options)

                    hs.append(h)

                    if "input_block_patch_after_skip" in transformer_patches:
                        patch = transformer_patches["input_block_patch_after_skip"]
                        for p in patch:
                            h = p(h, transformer_options)

                    # Skip forward if needed
                    if skip_forward and id == cache_depth:
                        break
                ####
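                # When this step is skipped, only input blocks 0..cache_depth
                # ran above; the middle block is bypassed and the deep output
                # blocks are replaced by the cached residual below.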
                if not skip_forward:
                    transformer_options["block"] = ("middle", 0)
                    if self.middle_block is not None:
                        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
                    h = apply_control(h, control, 'middle')

                for id, module in enumerate(self.output_blocks):
                    if skip_forward and id < len(self.output_blocks) - cache_depth - 1:
                        continue
                    transformer_options["block"] = ("output", id)
                    hsp = hs.pop()
                    hsp = apply_control(hsp, control, 'output')

                    if "output_block_patch" in transformer_patches:
                        patch = transformer_patches["output_block_patch"]
                        for p in patch:
                            h, hsp = p(h, hsp, transformer_options)
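                    # At the block where the truncated path resumes: during
                    # calibration, record the magnitude ratio between the
                    # current feature map and the one cached on the previous
                    # step; when skipping, reuse the cached residual;
                    # otherwise refresh the cache.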
                    ### Residual
                    if id == len(self.output_blocks) - cache_depth - 1:
                        if calibration and residual is not None:
                            ratios[timestep] = (h.norm(dim=1) / residual.norm(dim=1)).mean().item()
                            #cos_sims[timestep] = th.nn.functional.cosine_similarity(h, residual, dim=1).mean().item()
                        if skip_forward:
                            h = residual.clone()
                        else:
                            residual = h.clone()
                    ###

                    h = th.cat([h, hsp], dim=1)
                    del hsp

                    if len(hs) > 0:
                        output_shape = hs[-1].shape
                    else:
                        output_shape = None
                    h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
                h = h.type(x.dtype)
                ### Print ratios for calibration
                if step == sample_sigmas.shape[0] - 2 and calibration:
                    print("Copy the data below into the ratios input")
                    print(json.dumps(ratios))
                    #print("Copy the data below into cos_sims")
                    #print(json.dumps(cos_sims))
                    ratios = {}

                if self.predict_codebook_ids:
                    return self.id_predictor(h)
                else:
                    return self.out(h)
            return _forward
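
        # Swap in the hooked _forward only for the duration of each
        # apply_model call (via a comfy.patcher_extension wrapper), then
        # restore the original so other patchers see an unmodified model.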
        def wrapper_1(executor, *args, **kwargs):
            org_forward = new_model.model.diffusion_model._forward
            new_model.model.diffusion_model._forward = hook_forward(new_model.model.diffusion_model)
            try:
                result = executor(*args, **kwargs)
            finally:
                # Restore the original _forward even if sampling raises
                new_model.model.diffusion_model._forward = org_forward
            return result

        w = new_model.wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}).setdefault(None, [])
        w.append(wrapper_1)
        return (new_model, )
NODE_CLASS_MAPPINGS = {
    "MagCache": MagCache,
}

__all__ = ["NODE_CLASS_MAPPINGS"]
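
# Typical workflow (a sketch, based on the node's inputs): add the node with
# calibration=True, run one sampling pass at the intended step count, copy the
# JSON printed to the console into the `ratios` field, then set
# calibration=False to sample with skipping enabled.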