#!/usr/bin/env python3
"""
Script to scan directory for safetensor files and extract model, clip, and vae components
into separate folders using memory-efficient loading.
https://github.com/marduk191
"""
import os
import json
import struct
import torch
from pathlib import Path
from typing import Dict, Any, List
import argparse


class MemoryEfficientSafeOpen:
    """Lazily read tensors from a .safetensors file (metadata loading is not supported)."""

    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # offsets are relative to the start of the data region, so skip
            # the 8-byte size prefix plus the header itself
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size
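
    # On-disk layout of a .safetensors file: an 8-byte little-endian header
    # size, the JSON header itself, then the raw tensor data. Each header
    # entry stores dtype, shape, and byte offsets relative to the start of
    # the data region, which is why get_tensor() seeks past header_size + 8.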

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]

        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # make it writable
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)

        # process float8 types
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)

        # convert to the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if available
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
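

# Minimal usage sketch for MemoryEfficientSafeOpen (illustrative; the file
# path is a placeholder). Only one tensor's raw bytes are resident at a time:
#
#   with MemoryEfficientSafeOpen("model.safetensors") as f:
#       for key in f.keys():
#           tensor = f.get_tensor(key)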

def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """Save tensors to a .safetensors file without building the entire payload in memory."""
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated
print(f"Saving: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
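

# Round-trip sketch (illustrative; the file name and tensors are placeholders):
#
#   tensors = {"w": torch.randn(4, 4), "b": torch.zeros(4)}
#   mem_eff_save_file(tensors, "tiny.safetensors", metadata={"origin": "demo"})
#   with MemoryEfficientSafeOpen("tiny.safetensors") as f:
#       w = f.get_tensor("w")  # returns the same values that were saved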

def classify_tensor_key(key: str) -> str:
    """Classify tensor key into model, clip, or vae component"""
    key_lower = key.lower()

    # CLIP text encoder patterns
    clip_patterns = [
        'text_model', 'encoder.layers', 'text_encoder', 'clip',
        'transformer.text_model', 'cond_stage_model', 'text_projection',
        'token_embedding', 'position_embedding', 'ln_final'
    ]

    # VAE patterns
    vae_patterns = [
        'first_stage_model', 'vae', 'decoder', 'encoder.conv_in',
        'encoder.conv_out', 'encoder.down', 'decoder.conv_in',
        'decoder.conv_out', 'decoder.up', 'quant_conv', 'post_quant_conv'
    ]

    # Check for CLIP first; order matters because some keys could match
    # patterns in both lists
    for pattern in clip_patterns:
        if pattern in key_lower:
            return 'clip'

    # Check for VAE
    for pattern in vae_patterns:
        if pattern in key_lower:
            return 'vae'

    # Everything else is model (UNet, etc.)
    return 'model'
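

# Examples of the heuristic above (illustrative key names):
#   classify_tensor_key("cond_stage_model.transformer.text_model.embeddings") -> 'clip'
#   classify_tensor_key("first_stage_model.decoder.up.0.block.0.conv1.weight") -> 'vae'
#   classify_tensor_key("model.diffusion_model.input_blocks.0.0.weight") -> 'model'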

def find_safetensor_files(directory: str, recursive: bool = True) -> List[Path]:
    """Find all safetensor files in directory"""
    directory = Path(directory)
    if recursive:
        return list(directory.rglob("*.safetensors"))
    else:
        return list(directory.glob("*.safetensors"))

def extract_components(safetensor_path: Path, output_dir: Path, dry_run: bool = False):
    """Extract model, clip, and vae components from a safetensor file"""
    print(f"\nProcessing: {safetensor_path}")

    # Create component dictionaries
    components = {
        'model': {},
        'clip': {},
        'vae': {}
    }

    try:
        with MemoryEfficientSafeOpen(str(safetensor_path)) as f:
            keys = f.keys()
            print(f"Found {len(keys)} tensors")

            # Classify and load tensors
            for key in keys:
                component_type = classify_tensor_key(key)
                if not dry_run:
                    tensor = f.get_tensor(key)
                    components[component_type][key] = tensor
                else:
                    components[component_type][key] = None

        # Print classification summary
        print(f" Model tensors: {len(components['model'])}")
        print(f" CLIP tensors: {len(components['clip'])}")
        print(f" VAE tensors: {len(components['vae'])}")

        if dry_run:
            return

        # Create output directories
        base_name = safetensor_path.stem
        for component_type, tensors in components.items():
            if not tensors:  # Skip empty components
                continue
            component_dir = output_dir / component_type
            component_dir.mkdir(parents=True, exist_ok=True)
            output_file = component_dir / f"{base_name}_{component_type}.safetensors"
            mem_eff_save_file(tensors, str(output_file))
            print(f" Saved {component_type}: {output_file}")
    except Exception as e:
        print(f"Error processing {safetensor_path}: {e}")

def main():
    parser = argparse.ArgumentParser(
        description="Extract model, clip, and vae components from safetensor files"
    )
    parser.add_argument(
        "input_dir",
        help="Directory to scan for safetensor files"
    )
    parser.add_argument(
        "-o", "--output",
        default="extracted_components",
        help="Output directory for extracted components (default: extracted_components)"
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="Recursively scan subdirectories"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be extracted without actually doing it"
    )
    parser.add_argument(
        "--pattern",
        help="Only process files matching this pattern (e.g., '*diffusion*')"
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output)

    if not input_dir.exists():
        print(f"Error: Input directory {input_dir} does not exist")
        return

    # Find safetensor files
    safetensor_files = find_safetensor_files(input_dir, args.recursive)
    if args.pattern:
        from fnmatch import fnmatch
        safetensor_files = [f for f in safetensor_files if fnmatch(f.name, args.pattern)]

    if not safetensor_files:
        print("No safetensor files found")
        return

    print(f"Found {len(safetensor_files)} safetensor files")

    if args.dry_run:
        print("\n=== DRY RUN MODE ===")
    if not args.dry_run:
        output_dir.mkdir(parents=True, exist_ok=True)

    # Process each file
    for safetensor_file in safetensor_files:
        extract_components(safetensor_file, output_dir, args.dry_run)

    if not args.dry_run:
        print(f"\nExtraction complete. Components saved to: {output_dir}")
    else:
        print(f"\nDry run complete. Would save components to: {output_dir}")


if __name__ == "__main__":
    main()