@marduk191
Created July 23, 2025 19:13
Combine multiple CLIP and VAE components with the model for packaging.
#!/usr/bin/env python3
"""
Script to merge separated CLIP, VAE, and model safetensor files back into a single safetensor file.
https://github.com/marduk191
"""
import json
import struct
import torch
from pathlib import Path
from typing import Dict, Any, Optional, Set
import argparse
import re
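
# safetensors layout: an 8-byte little-endian header length, then a JSON header
# mapping tensor names to {dtype, shape, data_offsets}, then raw tensor data.
# data_offsets are relative to the start of the data section, hence the
# "header_size + 8" seek in get_tensor below.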
class MemoryEfficientSafeOpen:
    # does not support metadata loading
    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # adjust offset by header size
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]
        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
        # process float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)
        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)
    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if available
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        dtype = dtype_map.get(dtype_str)
        # fail loudly on unknown dtypes instead of silently returning None;
        # float8 strings fall through to _convert_float8, which raises its own error
        if dtype is None and dtype_str not in ("F8_E5M2", "F8_E4M3"):
            raise ValueError(f"Unsupported dtype: {dtype_str}")
        return dtype
    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Optional[Dict[str, Any]] = None):
    """Memory efficient save file"""
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
    }
    # add float8 types if available
    if hasattr(torch, "float8_e5m2"):
        _TYPES[torch.float8_e5m2] = "F8_E5M2"
    if hasattr(torch, "float8_e4m3fn"):
        _TYPES[torch.float8_e4m3fn] = "F8_E4M3"
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    print(f"Saving merged file: {filename}")
    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size
    hjson = json.dumps(header).encode("utf-8")
    # pad the header with spaces so the data section starts on an _ALIGN-byte boundary
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)
        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # Direct GPU to disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # if scalar, need to add a dimension to work with view
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)
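
# Expected on-disk layout for discovery: each component lives in its own
# subdirectory, e.g. <base_dir>/model/<name>_model.safetensors,
# <base_dir>/clip/<name>_clip.safetensors, <base_dir>/vae/<name>_vae.safetensors.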
def find_component_files(base_dir: Path, base_name: str) -> Dict[str, Optional[Path]]:
    """Find model, clip, and vae component files for a given base name"""
    components = {
        'model': None,
        'clip': None,
        'vae': None
    }
    # Search patterns for each component
    patterns = {
        'model': [f"{base_name}_model.safetensors", f"{base_name}.model.safetensors"],
        'clip': [f"{base_name}_clip.safetensors", f"{base_name}.clip.safetensors"],
        'vae': [f"{base_name}_vae.safetensors", f"{base_name}.vae.safetensors"]
    }
    for component_type in components:
        component_dir = base_dir / component_type
        if component_dir.exists():
            for pattern in patterns[component_type]:
                file_path = component_dir / pattern
                if file_path.exists():
                    components[component_type] = file_path
                    break
    return components

def find_all_base_names(input_dir: Path) -> Set[str]:
    """Find all unique base names from component files"""
    base_names = set()
    # Pattern to extract base name from component files
    pattern = re.compile(r'^(.+?)_(model|clip|vae)\.safetensors$')
    for component_type in ['model', 'clip', 'vae']:
        component_dir = input_dir / component_type
        if component_dir.exists():
            for file_path in component_dir.glob('*.safetensors'):
                match = pattern.match(file_path.name)
                if match:
                    base_names.add(match.group(1))
    return base_names
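
# Note: auto-discovery only recognizes the "<base>_<component>.safetensors"
# underscore naming; the dotted "<base>.<component>.safetensors" variant is
# accepted by find_component_files but never yields a base name here.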

def load_tensors_from_file(file_path: Path) -> Dict[str, torch.Tensor]:
    """Load all tensors from a safetensor file"""
    tensors = {}
    print(f"Loading tensors from: {file_path}")
    with MemoryEfficientSafeOpen(str(file_path)) as f:
        keys = f.keys()
        print(f"  Found {len(keys)} tensors")
        for key in keys:
            tensors[key] = f.get_tensor(key)
    return tensors

def merge_components(component_files: Dict[str, Optional[Path]],
                     output_path: Path,
                     preserve_metadata: bool = True) -> bool:
    """Merge component files into a single safetensor file"""
    merged_tensors = {}
    merged_metadata = {}
    # Load tensors from each component file
    for component_type, file_path in component_files.items():
        if file_path is None:
            print(f"Warning: No {component_type} file found")
            continue
        try:
            tensors = load_tensors_from_file(file_path)
            merged_tensors.update(tensors)
            # Extract metadata if requested (basic implementation)
            if preserve_metadata:
                merged_metadata[f"component_{component_type}"] = str(file_path.name)
        except Exception as e:
            print(f"Error loading {component_type} file {file_path}: {e}")
            return False
    if not merged_tensors:
        print("No tensors to merge!")
        return False
    # Save merged file
    try:
        metadata = merged_metadata if merged_metadata else None
        mem_eff_save_file(merged_tensors, str(output_path), metadata)
        print(f"Successfully merged {len(merged_tensors)} tensors")
        return True
    except Exception as e:
        print(f"Error saving merged file: {e}")
        return False

def merge_specific_components(model_file: Optional[Path] = None,
                              clip_file: Optional[Path] = None,
                              vae_file: Optional[Path] = None,
                              output_path: Optional[Path] = None,
                              preserve_metadata: bool = True) -> bool:
    """Merge specific component files"""
    component_files = {
        'model': model_file,
        'clip': clip_file,
        'vae': vae_file
    }
    # Filter out None values
    available_components = {k: v for k, v in component_files.items() if v is not None}
    if not available_components:
        print("No component files provided!")
        return False
    print(f"Merging components: {', '.join(available_components.keys())}")
    return merge_components(component_files, output_path, preserve_metadata)

def main():
    parser = argparse.ArgumentParser(
        description="Merge separated CLIP, VAE, and model safetensor files back into single files"
    )
    # Mode selection
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--auto",
        metavar="DIR",
        help="Auto-discover and merge all component sets in directory"
    )
    group.add_argument(
        "--manual",
        action="store_true",
        help="Manually specify component files"
    )
    # Auto mode options
    parser.add_argument(
        "-o", "--output-dir",
        default="merged_models",
        help="Output directory for merged files (auto mode, default: merged_models)"
    )
    # Manual mode options
    parser.add_argument(
        "--model",
        type=Path,
        help="Path to model component file"
    )
    parser.add_argument(
        "--clip",
        type=Path,
        help="Path to CLIP component file"
    )
    parser.add_argument(
        "--vae",
        type=Path,
        help="Path to VAE component file"
    )
    parser.add_argument(
        "--output",
        type=Path,
        help="Output path for merged file (manual mode)"
    )
    # Common options
    parser.add_argument(
        "--no-metadata",
        action="store_true",
        help="Don't preserve metadata in merged file"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be merged without doing it"
    )
    args = parser.parse_args()

    if args.auto:
        # Auto mode - discover and merge all component sets
        input_dir = Path(args.auto)
        output_dir = Path(args.output_dir)
        if not input_dir.exists():
            print(f"Error: Input directory {input_dir} does not exist")
            return
        # Find all base names
        base_names = find_all_base_names(input_dir)
        if not base_names:
            print("No component files found to merge")
            return
        print(f"Found {len(base_names)} model sets to merge: {', '.join(sorted(base_names))}")
        if not args.dry_run:
            output_dir.mkdir(parents=True, exist_ok=True)
        success_count = 0
        for base_name in sorted(base_names):
            print(f"\n=== Processing {base_name} ===")
            component_files = find_component_files(input_dir, base_name)
            available = [k for k, v in component_files.items() if v is not None]
            missing = [k for k, v in component_files.items() if v is None]
            print(f"Available components: {', '.join(available)}")
            if missing:
                print(f"Missing components: {', '.join(missing)}")
            if not available:
                print("No components found, skipping")
                continue
            output_path = output_dir / f"{base_name}.safetensors"
            if args.dry_run:
                print(f"Would merge to: {output_path}")
                continue
            if merge_components(component_files, output_path, not args.no_metadata):
                success_count += 1
                print(f"✓ Successfully merged {base_name}")
            else:
                print(f"✗ Failed to merge {base_name}")
        if not args.dry_run:
            print(f"\nMerging complete: {success_count}/{len(base_names)} successful")
        else:
            print(f"\nDry run complete: {len(base_names)} model sets found")
    else:
        # Manual mode - merge specific files
        if not args.output:
            print("Error: --output is required in manual mode")
            return
        component_files = [args.model, args.clip, args.vae]
        if not any(component_files):
            print("Error: At least one component file must be specified")
            return
        if args.dry_run:
            available = []
            if args.model:
                available.append(f"model: {args.model}")
            if args.clip:
                available.append(f"clip: {args.clip}")
            if args.vae:
                available.append(f"vae: {args.vae}")
            print("Would merge:")
            for comp in available:
                print(f"  {comp}")
            print(f"Output: {args.output}")
            return
        # honor --no-metadata in manual mode as well
        success = merge_specific_components(
            args.model, args.clip, args.vae, args.output, not args.no_metadata
        )
        if success:
            print(f"✓ Successfully merged to {args.output}")
        else:
            print("✗ Failed to merge components")

if __name__ == "__main__":
    main()
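
Example usage (the script filename merge_safetensors.py and the component filenames below are placeholders, not names from the gist):

    # Preview, then auto-discover every <name>_model/_clip/_vae set under ./split
    python merge_safetensors.py --auto ./split -o ./merged --dry-run
    python merge_safetensors.py --auto ./split -o ./merged

    # Manually merge specific component files into one checkpoint
    python merge_safetensors.py --manual --model x_model.safetensors --clip x_clip.safetensors --vae x_vae.safetensors --output x.safetensors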