Fast UNet merger
"""
https://github.com/marduk191
"""
import torch
import gc
from typing import Dict, Any, Tuple, Optional

import comfy.model_management as model_management
import comfy.utils


class FastUNetMerger:
    """
    Ultra-fast, low-memory UNet merger for ComfyUI.
    Uses in-place operations and optimized memory management.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_a": ("MODEL",),
                "model_b": ("MODEL",),
                "ratio": ("FLOAT", {
                    "default": 0.5,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                    "display": "slider"
                }),
                "merge_mode": (["lerp", "add", "subtract", "multiply"], {
                    "default": "lerp"
                }),
                "skip_clip": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Skip CLIP merging for faster processing"
                }),
                "skip_vae": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Skip VAE merging for faster processing"
                }),
                "use_fp16": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use half precision for memory savings"
                }),
                "aggressive_cleanup": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Aggressive memory cleanup between operations"
                })
            },
"optional": {
"block_weights": ("STRING", {
"default": "",
"tooltip": "Comma-separated weights for each block (e.g., 0.5,0.3,0.7)"
})
}
}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge_models"
    CATEGORY = "model_merging"

    def __init__(self):
        self.device = model_management.get_torch_device()
        self.offload_device = model_management.unet_offload_device()

    def parse_block_weights(self, block_weights_str: str, num_blocks: int) -> Optional[list]:
        """Parse a comma-separated block weights string into a list of floats."""
        if not block_weights_str.strip():
            return None
        try:
            weights = [float(w.strip()) for w in block_weights_str.split(',')]
            if len(weights) != num_blocks:
                print(f"Warning: Expected {num_blocks} weights, got {len(weights)}. Using ratio for all blocks.")
                return None
            return weights
        except ValueError:
            print("Warning: Invalid block weights format. Using ratio for all blocks.")
            return None

    def merge_tensor_inplace(self, tensor_a: torch.Tensor, tensor_b: torch.Tensor,
                             ratio: float, mode: str, target_tensor: torch.Tensor) -> None:
        """Merge tensors in-place to minimize memory allocation."""
        if mode == "lerp":
            # target = a * (1 - ratio) + b * ratio
            target_tensor.copy_(tensor_a)
            target_tensor.mul_(1.0 - ratio)
            target_tensor.add_(tensor_b, alpha=ratio)
        elif mode == "add":
            target_tensor.copy_(tensor_a)
            target_tensor.add_(tensor_b, alpha=ratio)
        elif mode == "subtract":
            target_tensor.copy_(tensor_a)
            target_tensor.sub_(tensor_b, alpha=ratio)
        elif mode == "multiply":
            target_tensor.copy_(tensor_a)
            target_tensor.mul_(tensor_b * ratio + (1.0 - ratio))
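
    # Element-wise math for the modes above:
    #   lerp:     target = a * (1 - ratio) + b * ratio
    #   add:      target = a + b * ratio
    #   subtract: target = a - b * ratio
    #   multiply: target = a * (b * ratio + (1 - ratio)), i.e. the multiplier
    #             is interpolated between 1 (ratio=0) and b (ratio=1)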

    def get_model_keys(self, model_dict: Dict[str, torch.Tensor]) -> set:
        """Get all keys from a model state dict."""
        return set(model_dict.keys())

    def merge_models(self, model_a, model_b, ratio: float, merge_mode: str,
                     skip_clip: bool, skip_vae: bool, use_fp16: bool,
                     aggressive_cleanup: bool, block_weights: str = "") -> Tuple[Any]:
        """
        Merge two models with optimized memory usage.
        """
        print(f"Starting fast UNet merge with ratio {ratio}, mode: {merge_mode}")

        # Clone model_a as base. Note: ComfyUI's ModelPatcher.clone() shares
        # the underlying torch module with model_a, so the in-place writes
        # below also modify model_a's loaded weights.
        merged_model = model_a.clone()

        # Get state dicts
        state_dict_a = model_a.model.state_dict()
        state_dict_b = model_b.model.state_dict()

        # Get common keys
        keys_a = self.get_model_keys(state_dict_a)
        keys_b = self.get_model_keys(state_dict_b)
        common_keys = keys_a.intersection(keys_b)
        if not common_keys:
            print("Warning: No common keys found between models")
            return (merged_model,)

        # Filter keys based on skip options
        filtered_keys = []
        for key in common_keys:
            if skip_clip and ("clip" in key.lower() or "text" in key.lower()):
                continue
            if skip_vae and ("vae" in key.lower() or "decoder" in key.lower() or "encoder" in key.lower()):
                continue
            filtered_keys.append(key)

        print(f"Merging {len(filtered_keys)} parameters...")

        # Parse block weights if provided (one weight per merged tensor)
        block_weights_list = self.parse_block_weights(block_weights, len(filtered_keys))

        # Get merged model state dict
        merged_state_dict = merged_model.model.state_dict()

        # Process in batches to manage memory
        batch_size = 10  # Process 10 tensors at a time
        processed = 0

        for i in range(0, len(filtered_keys), batch_size):
            batch_keys = filtered_keys[i:i + batch_size]
            for j, key in enumerate(batch_keys):
                try:
                    tensor_a = state_dict_a[key]
                    tensor_b = state_dict_b[key]

                    # Determine merge ratio for this tensor
                    current_ratio = ratio
                    if block_weights_list:
                        current_ratio = block_weights_list[processed + j]

                    # Convert to half precision if requested (.half() returns
                    # new tensors; the originals are untouched)
                    if use_fp16 and tensor_a.dtype == torch.float32:
                        tensor_a = tensor_a.half()
                        tensor_b = tensor_b.half()

                    # Ensure tensors are on the same device
                    if tensor_a.device != tensor_b.device:
                        tensor_b = tensor_b.to(tensor_a.device)

                    # Get target tensor from merged model and merge in-place.
                    # copy_()/add_() cast results back to the target's dtype,
                    # so no separate dtype restore is needed afterwards.
                    target_tensor = merged_state_dict[key]
                    self.merge_tensor_inplace(tensor_a, tensor_b, current_ratio,
                                              merge_mode, target_tensor)
                except Exception as e:
                    print(f"Error merging {key}: {e}")
                    continue

            processed += len(batch_keys)

            # Aggressive cleanup between batches
            if aggressive_cleanup and i > 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            print(f"Processed {processed}/{len(filtered_keys)} parameters...")

        # Final cleanup
        if aggressive_cleanup:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print("UNet merge completed successfully!")
        return (merged_model,)


# Register the node
NODE_CLASS_MAPPINGS = {
    "FastUNetMerger": FastUNetMerger
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FastUNetMerger": "Fast UNet Merger"
}
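
A quick way to sanity-check the in-place merge math outside ComfyUI is to mirror the merge_tensor_inplace pattern with plain torch tensors. This is a minimal sketch, assuming only that torch is installed; it does not use the node class or any ComfyUI API:

import torch

# Stand-ins for two model parameters and the merge target
a = torch.full((4,), 1.0)
b = torch.full((4,), 3.0)
target = torch.empty_like(a)

# lerp, as in merge_tensor_inplace: target = a * (1 - ratio) + b * ratio
ratio = 0.25
target.copy_(a)
target.mul_(1.0 - ratio)
target.add_(b, alpha=ratio)

assert torch.allclose(target, torch.full((4,), 1.5))  # 1.0*0.75 + 3.0*0.25 = 1.5

The same copy_/add_/sub_/mul_ pattern covers the other three modes, so each can be checked against the formulas listed after merge_tensor_inplace.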