Last active
June 24, 2025 11:07
-
-
Save laksjdjf/5335a4299b97c6b6fbae65a557620166 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
MagCache-style forward skipping for ComfyUI UNet models.
Reference implementation: https://github.com/Zehong-Ma/MagCache
'''
from comfy.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, timestep_embedding, th, apply_control | |
import comfy.patcher_extension | |
import json | |
def linear_interpolate(data: dict, num: float, scale: float) -> float:
    """
    Evaluate a piecewise-linear curve at ``num`` and raise the result to ``scale``.

    ``data`` maps x positions to y values. Positions outside the covered range
    are clamped to the nearest endpoint; no extrapolation is performed.

    :param data: non-empty mapping of numeric positions to numeric values.
    :param num: position at which the curve is evaluated.
    :param scale: exponent applied to the interpolated value.
    :return: the interpolated (or clamped) value raised to the power ``scale``.
    :raises ValueError: if ``data`` is empty, or if ``num`` cannot be ordered
        against the keys (for example NaN).
    """
    if not data:
        raise ValueError("データが空です")

    positions = sorted(data)
    first, last = positions[0], positions[-1]

    # Out-of-range positions return the endpoint value exactly instead of
    # extrapolating.
    if num <= first:
        return data[first] ** scale
    if num >= last:
        return data[last] ** scale

    # Locate the segment containing num. The half-open test makes a position
    # that coincides with a key resolve to that key's exact value.
    for left, right in zip(positions, positions[1:]):
        if left <= num < right:
            y_left, y_right = data[left], data[right]
            value = y_left + (y_right - y_left) * (num - left) / (right - left)
            return value ** scale

    # Only reachable when every comparison above was false (e.g. num is NaN).
    raise ValueError(f"cannot interpolate at position {num!r}")
class MagCache:
    """
    Node that patches a model so selected sampling steps run a shortened UNet forward.

    ``apply`` clones the given model and registers an APPLY_MODEL wrapper. While
    the wrapped call runs, the diffusion model's ``_forward`` is replaced by a
    version that, on steps marked for skipping, evaluates only the first
    ``cache_depth + 1`` input blocks and the matching last output blocks, and
    substitutes the hidden state cached from the most recent full forward for
    everything in between.

    In calibration mode nothing is skipped; instead the mean ratio between the
    channel-wise norms of the current and the previously cached hidden state is
    recorded per timestep and printed as JSON on the last step, to be pasted
    into the ``ratios`` input afterwards.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "cache_depth": ("INT", {
                    "default": 3,
                    "min": 0,
                    "max": 12,
                    "step": 1,
                    "display": "number"
                }),
                "calibration": ("BOOLEAN", {"default": False}),
                "k": ("INT", {
                    "default": 6,
                    "min": 0,
                    "max": 100,
                    "step": 1,
                    "display": "number"
                }),
                "threshold": ("FLOAT", {
                    "default": 0.01,
                    "min": -0.01,
                    "max": 10,
                    "step": 0.001,
                    "round": 0.001,
                }),
                "retention": ("FLOAT", {
                    "default": 0.2,
                    "min": 0,
                    "max": 1.0,
                    "step": 0.01,
                    "round": 0.01,
                }),
                "ratios": ("STRING", {"multiline": True, "default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    @staticmethod
    def _parse_ratios(text):
        """
        Parse the JSON text of the ``ratios`` input into a ``{timestep: ratio}`` dict.

        Keys are converted with ``float`` so that both ``"999"`` and ``"999.0"``
        are accepted; ``json.dumps`` writes the latter whenever the calibrated
        timestep was a float.

        :param text: JSON object mapping timestep strings to ratio values.
        :return: dict with float keys.
        :raises ValueError: if the text is not valid JSON or is not a non-empty
            JSON object.
        """
        parsed = json.loads(text)
        if not isinstance(parsed, dict) or not parsed:
            raise ValueError("ratios must be a non-empty JSON object mapping timestep to ratio")
        return {float(key): value for key, value in parsed.items()}

    @staticmethod
    def _plan_skips(ratios, timesteps, total_steps, retention, threshold, k):
        """
        Decide, for every sampling step, whether the shortened forward is used.

        The first ``max(1, retention * total_steps)`` steps are never skipped.
        After that, interpolated ratios are multiplied together and the
        deviation of that product from 1.0 is summed; a step is skipped while
        the summed deviation stays within ``threshold`` and at most ``k``
        consecutive steps have been skipped. Otherwise the step runs in full and
        the accumulators restart.

        :param ratios: calibrated ``{timestep: ratio}`` mapping.
        :param timesteps: per-step timestep values, indexable by step.
        :param total_steps: number of sampling steps to plan.
        :param retention: fraction of leading steps that always run in full.
        :param threshold: upper bound for the accumulated deviation.
        :param k: upper bound for consecutive skipped steps.
        :return: list of ``total_steps`` booleans, True meaning "skip".
        """
        warmup_steps = max(1, retention * total_steps)
        # Exponent applied to every interpolated ratio. NOTE(review): presumably
        # rescales per-step ratios when the current step count differs from the
        # calibrated one - confirm. It is only used for steps past the warm-up,
        # which requires total_steps >= 2, so the guard just avoids 0-division.
        scale = (len(ratios) - 1) / (total_steps - 1) if total_steps > 1 else 1.0

        plan = []
        running_ratio, running_error, run_length = 1.0, 0.0, 0
        for index in range(total_steps):
            if index < warmup_steps:
                plan.append(False)
                continue

            running_ratio *= linear_interpolate(ratios, timesteps[index], scale)
            running_error += abs(running_ratio - 1.0)
            run_length += 1

            if running_error <= threshold and run_length <= k:
                plan.append(True)
            else:
                plan.append(False)
                running_ratio, running_error, run_length = 1.0, 0.0, 0
        return plan

    def apply(self, model, cache_depth, calibration, k, threshold, retention, ratios):
        """
        Return a clone of ``model`` with the skipping/calibration forward attached.

        :param model: model patcher object to clone and patch.
        :param cache_depth: index of the last input block evaluated on a
            skipped step; the matching number of trailing output blocks is kept.
        :param calibration: when True, never skip and record ratios instead.
        :param k: maximum number of consecutive skipped steps.
        :param threshold: maximum accumulated deviation before a full step.
        :param retention: fraction of leading steps that always run in full.
        :param ratios: JSON text produced by a calibration run; ignored when
            ``calibration`` is True.
        :return: one-element tuple holding the patched clone.
        :raises ValueError: if ``calibration`` is False and ``ratios`` is empty
            or not a valid JSON object.
        """
        # Validate with real exceptions: assert statements vanish under -O.
        if calibration:
            ratios = {}
        else:
            if ratios.strip() == "":
                raise ValueError("ratios must be specified if not calibration")
            ratios = self._parse_ratios(ratios)

        new_model = model.clone()
        plan_skips = self._plan_skips

        # State shared between successive forward calls of one sampling run.
        residual = None      # hidden state cached by the latest full forward
        skip_forwards = []   # per-step skip decisions, rebuilt at step 0

        def hook_forward(unet):
            def _forward(x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
                """
                Apply the model to an input batch.
                :param x: an [N x C x ...] Tensor of inputs.
                :param timesteps: a 1-D batch of timesteps.
                :param context: conditioning plugged in via crossattn
                :param y: an [N] Tensor of labels, if class-conditional.
                :return: an [N x C x ...] Tensor of outputs.
                """
                nonlocal ratios, residual, skip_forwards
                if transformer_options is None:
                    transformer_options = {}

                ######## Calc current step and skip forward
                sample_sigmas = transformer_options["sample_sigmas"]
                sigmas = transformer_options["sigmas"]
                # The current step is the schedule entry closest to this call's sigma.
                step = (sample_sigmas - sigmas[0].item()).abs().argmin().item()
                timestep_total = new_model.model.model_sampling.timestep(sample_sigmas)
                timestep = timestep_total[step].item()
                total_steps = sample_sigmas.shape[0] - 1

                if step == 0:
                    # A new sampling run starts: drop the cache and re-plan.
                    residual = None
                    skip_forwards = []
                    if not calibration:
                        skip_forwards = plan_skips(ratios, timestep_total.tolist(), total_steps, retention, threshold, k)
                        skip_indices = [i for i, skip in enumerate(skip_forwards) if skip]
                        print(f"Skip {len(skip_indices)} steps ({skip_indices})")

                # Only skip when a cached activation exists; otherwise run the
                # full forward (this also covers a cache_depth that never
                # matches any output block, in which case nothing is cached).
                skip_forward = (not calibration) and skip_forwards[step] and residual is not None
                ########

                transformer_options["original_shape"] = list(x.shape)
                transformer_options["transformer_index"] = 0
                transformer_patches = transformer_options.get("patches", {})

                num_video_frames = kwargs.get("num_video_frames", unet.default_num_video_frames)
                image_only_indicator = kwargs.get("image_only_indicator", None)
                time_context = kwargs.get("time_context", None)

                assert (y is not None) == (
                    unet.num_classes is not None
                ), "must specify y if and only if the model is class-conditional"
                hs = []
                t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(x.dtype)
                emb = unet.time_embed(t_emb)

                if "emb_patch" in transformer_patches:
                    for p in transformer_patches["emb_patch"]:
                        emb = p(emb, unet.model_channels, transformer_options)

                if unet.num_classes is not None:
                    assert y.shape[0] == x.shape[0]
                    emb = emb + unet.label_emb(y)

                h = x
                for block_index, module in enumerate(unet.input_blocks):
                    transformer_options["block"] = ("input", block_index)
                    h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
                    h = apply_control(h, control, 'input')
                    if "input_block_patch" in transformer_patches:
                        for p in transformer_patches["input_block_patch"]:
                            h = p(h, transformer_options)

                    hs.append(h)
                    if "input_block_patch_after_skip" in transformer_patches:
                        for p in transformer_patches["input_block_patch_after_skip"]:
                            h = p(h, transformer_options)

                    # On a skipped step only the shallow input blocks are run.
                    if skip_forward and block_index == cache_depth:
                        break

                if not skip_forward:
                    transformer_options["block"] = ("middle", 0)
                    if unet.middle_block is not None:
                        h = forward_timestep_embed(unet.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
                    h = apply_control(h, control, 'middle')

                # Output block at which the cached hidden state is stored/reused.
                reuse_index = len(unet.output_blocks) - cache_depth - 1
                for block_index, module in enumerate(unet.output_blocks):
                    if skip_forward and block_index < reuse_index:
                        continue
                    transformer_options["block"] = ("output", block_index)
                    hsp = hs.pop()
                    hsp = apply_control(hsp, control, 'output')

                    if "output_block_patch" in transformer_patches:
                        for p in transformer_patches["output_block_patch"]:
                            h, hsp = p(h, hsp, transformer_options)

                    ### Residual
                    if block_index == reuse_index:
                        if calibration and residual is not None:
                            # Mean ratio of channel-wise norms: current vs. cached.
                            ratios[timestep] = (h.norm(dim=1) / residual.norm(dim=1)).mean().item()
                        if skip_forward:
                            h = residual.clone()
                        else:
                            residual = h.clone()
                    ###

                    h = th.cat([h, hsp], dim=1)
                    del hsp
                    output_shape = hs[-1].shape if len(hs) > 0 else None
                    h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
                h = h.type(x.dtype)

                ### Print ratios for calibration
                if calibration and step == sample_sigmas.shape[0] - 2:
                    print("Copy below data to ratios")
                    print(json.dumps(ratios))
                    ratios = {}

                if unet.predict_codebook_ids:
                    return unet.id_predictor(h)
                else:
                    return unet.out(h)
            return _forward

        def wrapper_1(executor, *args, **kwargs):
            diffusion_model = new_model.model.diffusion_model
            org_forward = diffusion_model._forward
            diffusion_model._forward = hook_forward(diffusion_model)
            try:
                return executor(*args, **kwargs)
            finally:
                # Always undo the patch, even if sampling raises or is
                # interrupted, so the hooked forward never stays installed.
                diffusion_model._forward = org_forward

        w = new_model.wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}).setdefault(None, [])
        w.append(wrapper_1)
        return (new_model, )
# Maps the node's registered name to its implementing class.
# NOTE(review): presumably read by ComfyUI's custom-node loader - confirm.
NODE_CLASS_MAPPINGS = {"MagCache": MagCache}

# Only the mapping is exported from this module.
__all__ = ["NODE_CLASS_MAPPINGS"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment