Combine separated CLIP, VAE, and model safetensor files back into a single checkpoint for packaging.
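
In auto mode the script looks for the split files in model/, clip/, and vae/ subdirectories of the input directory, named {base}_model.safetensors, {base}_clip.safetensors, and {base}_vae.safetensors (a dot separator, e.g. {base}.model.safetensors, is also accepted). With a hypothetical base name flux_dev, a mergeable layout looks like:

    split_models/
        model/flux_dev_model.safetensors
        clip/flux_dev_clip.safetensors
        vae/flux_dev_vae.safetensors

Each discovered base name is merged into <output-dir>/{base}.safetensors.
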
#!/usr/bin/env python3
"""
Script to merge separated CLIP, VAE, and model safetensor files back into a single safetensor file.
https://github.com/marduk191
"""
import argparse
import json
import re
import struct
from pathlib import Path
from typing import Any, Dict, Optional, Set

import torch

class MemoryEfficientSafeOpen:
    # does not support metadata loading
    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # adjust offset by header size
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]
        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
        # process float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)
        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if available
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")

def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """Memory efficient save file"""
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    print(f"Saving merged file: {filename}")

    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size

    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)
        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # Direct GPU to disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)
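
# mem_eff_save_file keeps peak memory low by streaming one tensor at a time to
# disk rather than materializing the whole file in memory. The header is padded
# with spaces so that tensor data starts on a 256-byte (_ALIGN) boundary; the
# safetensors format tolerates trailing whitespace in the JSON header.
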

def find_component_files(base_dir: Path, base_name: str) -> Dict[str, Optional[Path]]:
    """Find model, clip, and vae component files for a given base name"""
    components = {
        'model': None,
        'clip': None,
        'vae': None
    }
    # Search patterns for each component
    patterns = {
        'model': [f"{base_name}_model.safetensors", f"{base_name}.model.safetensors"],
        'clip': [f"{base_name}_clip.safetensors", f"{base_name}.clip.safetensors"],
        'vae': [f"{base_name}_vae.safetensors", f"{base_name}.vae.safetensors"]
    }
    for component_type in components:
        component_dir = base_dir / component_type
        if component_dir.exists():
            for pattern in patterns[component_type]:
                file_path = component_dir / pattern
                if file_path.exists():
                    components[component_type] = file_path
                    break
    return components

def find_all_base_names(input_dir: Path) -> Set[str]:
    """Find all unique base names from component files"""
    base_names = set()
    # Extract the base name from component files; accept both the '_' and '.'
    # separators that find_component_files() searches for
    pattern = re.compile(r'^(.+?)[._](model|clip|vae)\.safetensors$')
    for component_type in ['model', 'clip', 'vae']:
        component_dir = input_dir / component_type
        if component_dir.exists():
            for file_path in component_dir.glob('*.safetensors'):
                match = pattern.match(file_path.name)
                if match:
                    base_names.add(match.group(1))
    return base_names

def load_tensors_from_file(file_path: Path) -> Dict[str, torch.Tensor]:
    """Load all tensors from a safetensor file"""
    tensors = {}
    print(f"Loading tensors from: {file_path}")
    with MemoryEfficientSafeOpen(str(file_path)) as f:
        keys = f.keys()
        print(f"  Found {len(keys)} tensors")
        for key in keys:
            tensors[key] = f.get_tensor(key)
    return tensors

def merge_components(component_files: Dict[str, Optional[Path]],
                     output_path: Path,
                     preserve_metadata: bool = True) -> bool:
    """Merge component files into a single safetensor file"""
    merged_tensors = {}
    merged_metadata = {}
    # Load tensors from each component file
    for component_type, file_path in component_files.items():
        if file_path is None:
            print(f"Warning: No {component_type} file found")
            continue
        try:
            tensors = load_tensors_from_file(file_path)
            merged_tensors.update(tensors)
            # Extract metadata if requested (basic implementation)
            if preserve_metadata:
                merged_metadata[f"component_{component_type}"] = str(file_path.name)
        except Exception as e:
            print(f"Error loading {component_type} file {file_path}: {e}")
            return False
    if not merged_tensors:
        print("No tensors to merge!")
        return False
    # Save merged file
    try:
        metadata = merged_metadata if merged_metadata else None
        mem_eff_save_file(merged_tensors, str(output_path), metadata)
        print(f"Successfully merged {len(merged_tensors)} tensors")
        return True
    except Exception as e:
        print(f"Error saving merged file: {e}")
        return False
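
# Note on collisions: merge_components() combines tensors with dict.update(),
# so if two component files contain the same tensor name, the one loaded last
# silently wins. Model, CLIP, and VAE key spaces are normally disjoint in
# split checkpoints, so this rarely matters in practice.
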

def merge_specific_components(model_file: Optional[Path] = None,
                              clip_file: Optional[Path] = None,
                              vae_file: Optional[Path] = None,
                              output_path: Optional[Path] = None,
                              preserve_metadata: bool = True) -> bool:
    """Merge specific component files"""
    component_files = {
        'model': model_file,
        'clip': clip_file,
        'vae': vae_file
    }
    # Filter out None values
    available_components = {k: v for k, v in component_files.items() if v is not None}
    if not available_components:
        print("No component files provided!")
        return False
    print(f"Merging components: {', '.join(available_components.keys())}")
    return merge_components(component_files, output_path, preserve_metadata)

def main():
    parser = argparse.ArgumentParser(
        description="Merge separated CLIP, VAE, and model safetensor files back into single files"
    )
    # Mode selection
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--auto",
        metavar="DIR",
        help="Auto-discover and merge all component sets in directory"
    )
    group.add_argument(
        "--manual",
        action="store_true",
        help="Manually specify component files"
    )
    # Auto mode options
    parser.add_argument(
        "-o", "--output-dir",
        default="merged_models",
        help="Output directory for merged files (auto mode, default: merged_models)"
    )
    # Manual mode options
    parser.add_argument(
        "--model",
        type=Path,
        help="Path to model component file"
    )
    parser.add_argument(
        "--clip",
        type=Path,
        help="Path to CLIP component file"
    )
    parser.add_argument(
        "--vae",
        type=Path,
        help="Path to VAE component file"
    )
    parser.add_argument(
        "--output",
        type=Path,
        help="Output path for merged file (manual mode)"
    )
    # Common options
    parser.add_argument(
        "--no-metadata",
        action="store_true",
        help="Don't preserve metadata in merged file"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be merged without doing it"
    )
    args = parser.parse_args()
    if args.auto:
        # Auto mode - discover and merge all component sets
        input_dir = Path(args.auto)
        output_dir = Path(args.output_dir)
        if not input_dir.exists():
            print(f"Error: Input directory {input_dir} does not exist")
            return
        # Find all base names
        base_names = find_all_base_names(input_dir)
        if not base_names:
            print("No component files found to merge")
            return
        print(f"Found {len(base_names)} model sets to merge: {', '.join(sorted(base_names))}")
        if not args.dry_run:
            output_dir.mkdir(parents=True, exist_ok=True)
        success_count = 0
        for base_name in sorted(base_names):
            print(f"\n=== Processing {base_name} ===")
            component_files = find_component_files(input_dir, base_name)
            available = [k for k, v in component_files.items() if v is not None]
            missing = [k for k, v in component_files.items() if v is None]
            print(f"Available components: {', '.join(available)}")
            if missing:
                print(f"Missing components: {', '.join(missing)}")
            if not available:
                print("No components found, skipping")
                continue
            output_path = output_dir / f"{base_name}.safetensors"
            if args.dry_run:
                print(f"Would merge to: {output_path}")
                continue
            if merge_components(component_files, output_path, not args.no_metadata):
                success_count += 1
                print(f"✓ Successfully merged {base_name}")
            else:
                print(f"✗ Failed to merge {base_name}")
        if not args.dry_run:
            print(f"\nMerging complete: {success_count}/{len(base_names)} successful")
        else:
            print(f"\nDry run complete: {len(base_names)} model sets found")
    else:
        # Manual mode - merge specific files
        if not args.output:
            print("Error: --output is required in manual mode")
            return
        component_files = [args.model, args.clip, args.vae]
        if not any(component_files):
            print("Error: At least one component file must be specified")
            return
        if args.dry_run:
            available = []
            if args.model: available.append(f"model: {args.model}")
            if args.clip: available.append(f"clip: {args.clip}")
            if args.vae: available.append(f"vae: {args.vae}")
            print("Would merge:")
            for comp in available:
                print(f"  {comp}")
            print(f"Output: {args.output}")
            return
        success = merge_specific_components(
            args.model, args.clip, args.vae, args.output,
            preserve_metadata=not args.no_metadata
        )
        if success:
            print(f"✓ Successfully merged to {args.output}")
        else:
            print("✗ Failed to merge components")


if __name__ == "__main__":
    main()
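
Example invocations, assuming the script is saved as merge_safetensors.py (the directory and file names below are placeholders):

    # Auto-discover every {base}_model/_clip/_vae set under split_models/
    # and write merged checkpoints into merged_models/
    python merge_safetensors.py --auto split_models -o merged_models

    # Preview what auto mode would do without writing anything
    python merge_safetensors.py --auto split_models --dry-run

    # Merge explicitly named component files (any subset of --model/--clip/--vae)
    python merge_safetensors.py --manual --model flux_model.safetensors --clip flux_clip.safetensors --vae flux_vae.safetensors --output flux_merged.safetensors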