Fast UNet merger
| """ | |
| https://github.com/marduk191 | |
| """ | |
| import torch | |
| import gc | |
| from typing import Dict, Any, Tuple, Optional | |
| import comfy.model_management as model_management | |
| import comfy.utils | |

class FastUNetMerger:
    """
    Ultra-fast, low-memory UNet merger for ComfyUI.
    Uses in-place operations and optimized memory management.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_a": ("MODEL",),
                "model_b": ("MODEL",),
                "ratio": ("FLOAT", {
                    "default": 0.5,
                    "min": 0.0,
                    "max": 1.0,
                    "step": 0.01,
                    "display": "slider"
                }),
                "merge_mode": (["lerp", "add", "subtract", "multiply"], {
                    "default": "lerp"
                }),
                "skip_clip": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Skip CLIP merging for faster processing"
                }),
                "skip_vae": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Skip VAE merging for faster processing"
                }),
                "use_fp16": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use half precision for memory savings"
                }),
                "aggressive_cleanup": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Aggressive memory cleanup between operations"
                })
            },
            "optional": {
                "block_weights": ("STRING", {
                    "default": "",
                    "tooltip": "Comma-separated weights, one per merged tensor (e.g., 0.5,0.3,0.7)"
                })
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge_models"
    CATEGORY = "model_merging"

    def __init__(self):
        self.device = model_management.get_torch_device()
        self.offload_device = model_management.unet_offload_device()

    def parse_block_weights(self, block_weights_str: str, num_blocks: int) -> Optional[list]:
        """Parse a comma-separated weights string into a list of floats."""
        if not block_weights_str.strip():
            return None
        try:
            weights = [float(w.strip()) for w in block_weights_str.split(',')]
            if len(weights) != num_blocks:
                print(f"Warning: Expected {num_blocks} weights, got {len(weights)}. Using ratio for all blocks.")
                return None
            return weights
        except ValueError:
            print("Warning: Invalid block weights format. Using ratio for all blocks.")
            return None
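
    # Illustrative behavior of parse_block_weights:
    #   parse_block_weights("0.5, 0.3, 0.7", 3) -> [0.5, 0.3, 0.7]
    #   parse_block_weights("0.5, 0.3", 3)      -> None (count mismatch; falls back to ratio)
    #   parse_block_weights("", 3)              -> None (empty string; falls back to ratio)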

    def merge_tensor_inplace(self, tensor_a: torch.Tensor, tensor_b: torch.Tensor,
                             ratio: float, mode: str, target_tensor: torch.Tensor) -> None:
        """Merge tensors in-place to minimize memory allocation."""
        if mode == "lerp":
            # target = a * (1 - ratio) + b * ratio
            target_tensor.copy_(tensor_a)
            target_tensor.mul_(1.0 - ratio)
            target_tensor.add_(tensor_b, alpha=ratio)
        elif mode == "add":
            # target = a + b * ratio
            target_tensor.copy_(tensor_a)
            target_tensor.add_(tensor_b, alpha=ratio)
        elif mode == "subtract":
            # target = a - b * ratio
            target_tensor.copy_(tensor_a)
            target_tensor.sub_(tensor_b, alpha=ratio)
        elif mode == "multiply":
            # target = a * (b * ratio + (1 - ratio)); the scale factor is the
            # one temporary allocation in this method
            target_tensor.copy_(tensor_a)
            target_tensor.mul_(tensor_b * ratio + (1.0 - ratio))
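
    # Worked example (illustrative): "lerp" at ratio 0.25 on toy tensors.
    #   a = torch.full((2,), 1.0); b = torch.zeros(2); out = torch.empty(2)
    #   out.copy_(a).mul_(0.75).add_(b, alpha=0.25)   # -> tensor([0.7500, 0.7500])
    # copy_/mul_/add_ each return their receiver, so the chain mutates `out`
    # without allocating intermediate tensors.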

    def get_model_keys(self, model_dict: Dict[str, torch.Tensor]) -> set:
        """Get all keys from a model state dict."""
        return set(model_dict.keys())

    def merge_models(self, model_a, model_b, ratio: float, merge_mode: str,
                     skip_clip: bool, skip_vae: bool, use_fp16: bool,
                     aggressive_cleanup: bool, block_weights: str = "") -> Tuple[Any]:
        """
        Merge two models with optimized memory usage.
        """
        print(f"Starting fast UNet merge with ratio {ratio}, mode: {merge_mode}")

        # Clone model_a as the base for the merged model
        merged_model = model_a.clone()

        # Get state dicts
        state_dict_a = model_a.model.state_dict()
        state_dict_b = model_b.model.state_dict()

        # Get the keys present in both models
        keys_a = self.get_model_keys(state_dict_a)
        keys_b = self.get_model_keys(state_dict_b)
        common_keys = keys_a.intersection(keys_b)

        if not common_keys:
            print("Warning: No common keys found between models")
            return (merged_model,)

        # Filter keys based on skip options; sort for a deterministic order,
        # which matters when block_weights are applied positionally
        filtered_keys = []
        for key in sorted(common_keys):
            if skip_clip and ("clip" in key.lower() or "text" in key.lower()):
                continue
            if skip_vae and ("vae" in key.lower() or "decoder" in key.lower() or "encoder" in key.lower()):
                continue
            filtered_keys.append(key)

        print(f"Merging {len(filtered_keys)} parameters...")

        # Parse block weights if provided; note they are applied per parameter
        # tensor (one weight per merged key), not per UNet block
        block_weights_list = self.parse_block_weights(block_weights, len(filtered_keys))

        # state_dict() returns live references to the parameter tensors, so
        # the in-place writes below update the merged model directly
        merged_state_dict = merged_model.model.state_dict()

        # Process in batches to manage memory
        batch_size = 10  # Process 10 tensors at a time
        processed = 0

        for i in range(0, len(filtered_keys), batch_size):
            batch_keys = filtered_keys[i:i + batch_size]

            for j, key in enumerate(batch_keys):
                try:
                    tensor_a = state_dict_a[key]
                    tensor_b = state_dict_b[key]

                    # Determine the merge ratio for this tensor
                    current_ratio = ratio
                    if block_weights_list:
                        current_ratio = block_weights_list[processed + j]

                    # Convert to half precision if requested
                    if use_fp16 and tensor_a.dtype == torch.float32:
                        tensor_a = tensor_a.half()
                        tensor_b = tensor_b.half()

                    # Ensure the source tensors are on the same device
                    if tensor_a.device != tensor_b.device:
                        tensor_b = tensor_b.to(tensor_a.device)

                    # Get the target tensor from the merged model
                    target_tensor = merged_state_dict[key]

                    # Merge in-place; copy_ casts the (possibly fp16) sources
                    # back to the target tensor's own dtype, so no explicit
                    # conversion back is needed afterwards
                    self.merge_tensor_inplace(tensor_a, tensor_b, current_ratio,
                                              merge_mode, target_tensor)
                except Exception as e:
                    print(f"Error merging {key}: {e}")
                    continue

            processed += len(batch_keys)

            # Aggressive cleanup between batches
            if aggressive_cleanup and i > 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            print(f"Processed {processed}/{len(filtered_keys)} parameters...")

        # Final cleanup
        if aggressive_cleanup:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print("UNet merge completed successfully!")
        return (merged_model,)

# Register the node
NODE_CLASS_MAPPINGS = {
    "FastUNetMerger": FastUNetMerger
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FastUNetMerger": "Fast UNet Merger"
}
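
To try the node, save the file into ComfyUI's custom_nodes directory (a filename such as ComfyUI/custom_nodes/fast_unet_merger.py is illustrative); ComfyUI reads NODE_CLASS_MAPPINGS from files in that folder on startup, and the node then appears under the model_merging category. The merge math itself can be sanity-checked outside ComfyUI with plain PyTorch; a minimal sketch, assuming only torch is installed and using made-up tensor values:

import torch

a = torch.full((4,), 2.0)   # stands in for a parameter from model_a
b = torch.full((4,), 4.0)   # the same parameter from model_b
out = torch.empty(4)

# "lerp" at ratio 0.5: out = a * 0.5 + b * 0.5 -> 3.0 everywhere
out.copy_(a).mul_(0.5).add_(b, alpha=0.5)
print(out)  # tensor([3., 3., 3., 3.])

# "subtract" at ratio 0.25: out = a - b * 0.25 -> 1.0 everywhere
out.copy_(a).sub_(b, alpha=0.25)
print(out)  # tensor([1., 1., 1., 1.])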