#!/usr/bin/env python3
"""
Scan a directory for .safetensors files and extract the model (UNet), CLIP,
and VAE components into separate folders, using memory-efficient loading.
https://github.com/marduk191
"""
import argparse
import json
import struct
from pathlib import Path
from typing import Any, Dict, List

import torch
class MemoryEfficientSafeOpen:
    """Minimal safetensors reader that loads one tensor at a time.

    Does not expose the file-level __metadata__ entry.
    """

    def __init__(self, filename):
        self.filename = filename
        self.header, self.header_size = self._read_header()
        self.file = open(filename, "rb")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self):
        return [k for k in self.header.keys() if k != "__metadata__"]

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        if offset_start == offset_end:
            tensor_bytes = None
        else:
            # data offsets are relative to the end of the header:
            # skip the 8-byte length prefix plus the header itself
            self.file.seek(self.header_size + 8 + offset_start)
            tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        with open(self.filename, "rb") as f:
            header_size = struct.unpack("<Q", f.read(8))[0]
            header_json = f.read(header_size).decode("utf-8")
            return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._get_torch_dtype(metadata["dtype"])
        shape = metadata["shape"]
        if tensor_bytes is None:
            byte_tensor = torch.empty(0, dtype=torch.uint8)
        else:
            tensor_bytes = bytearray(tensor_bytes)  # torch.frombuffer needs a writable buffer
            byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
        # float8 types need a dedicated code path because older PyTorch lacks them
        if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
            return self._convert_float8(byte_tensor, metadata["dtype"], shape)
        # reinterpret the raw bytes as the target dtype and reshape
        return byte_tensor.view(dtype).reshape(shape)

    @staticmethod
    def _get_torch_dtype(dtype_str):
        dtype_map = {
            "F64": torch.float64,
            "F32": torch.float32,
            "F16": torch.float16,
            "BF16": torch.bfloat16,
            "I64": torch.int64,
            "I32": torch.int32,
            "I16": torch.int16,
            "I8": torch.int8,
            "U8": torch.uint8,
            "BOOL": torch.bool,
        }
        # add float8 types if this PyTorch build has them
        if hasattr(torch, "float8_e5m2"):
            dtype_map["F8_E5M2"] = torch.float8_e5m2
        if hasattr(torch, "float8_e4m3fn"):
            dtype_map["F8_E4M3"] = torch.float8_e4m3fn
        return dtype_map.get(dtype_str)

    @staticmethod
    def _convert_float8(byte_tensor, dtype_str, shape):
        if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
            return byte_tensor.view(torch.float8_e5m2).reshape(shape)
        elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
            return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
        else:
            raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
    """Memory-efficient save: stream tensors to disk one at a time."""
    # the getattr guards keep this dict valid on PyTorch builds without float8
    # (the resulting None keys are never looked up)
    _TYPES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
    }
    _ALIGN = 256

    def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
        validated = {}
        for key, value in metadata.items():
            if not isinstance(key, str):
                raise ValueError(f"Metadata key must be a string, got {type(key)}")
            if not isinstance(value, str):
                print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
                validated[key] = str(value)
            else:
                validated[key] = value
        return validated

    print(f"Saving: {filename}")
    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = validate_metadata(metadata)
    for k, v in tensors.items():
        if v.numel() == 0:  # empty tensor
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
        else:
            size = v.numel() * v.element_size()
            header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
            offset += size
    hjson = json.dumps(header).encode("utf-8")
    # pad the header with spaces so the tensor data starts on a 256-byte boundary
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
    with open(filename, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)))
        f.write(hjson)
        for k, v in tensors.items():
            if v.numel() == 0:
                continue
            if v.is_cuda:
                # direct GPU-to-disk save
                with torch.cuda.device(v.device):
                    if v.dim() == 0:  # scalars need a dimension for view() to work
                        v = v.unsqueeze(0)
                    tensor_bytes = v.contiguous().view(torch.uint8)
                    tensor_bytes.cpu().numpy().tofile(f)
            else:
                # CPU tensor save
                if v.dim() == 0:  # scalars need a dimension for view() to work
                    v = v.unsqueeze(0)
                v.contiguous().view(torch.uint8).numpy().tofile(f)
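# On-disk layout written by mem_eff_save_file, matching the safetensors format:
#   [8 bytes]  little-endian uint64 header length
#   [header]   JSON, space-padded so tensor data starts on a 256-byte boundary
#   [data]     raw tensor bytes, back to back, at the offsets declared above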
def classify_tensor_key(key: str) -> str:
    """Classify a tensor key as a model, clip, or vae component."""
    key_lower = key.lower()
    # CLIP text encoder patterns
    clip_patterns = [
        'text_model', 'encoder.layers', 'text_encoder', 'clip',
        'transformer.text_model', 'cond_stage_model', 'text_projection',
        'token_embedding', 'position_embedding', 'ln_final'
    ]
    # VAE patterns
    vae_patterns = [
        'first_stage_model', 'vae', 'decoder', 'encoder.conv_in',
        'encoder.conv_out', 'encoder.down', 'decoder.conv_in',
        'decoder.conv_out', 'decoder.up', 'quant_conv', 'post_quant_conv'
    ]
    # CLIP is checked first, so a key matching both lists is classified as clip
    for pattern in clip_patterns:
        if pattern in key_lower:
            return 'clip'
    for pattern in vae_patterns:
        if pattern in key_lower:
            return 'vae'
    # everything else is treated as the model (UNet, etc.)
    return 'model'
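# Illustrative classifications (SD-style keys; examples only, not from any
# specific checkpoint):
#   classify_tensor_key("cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight")  # 'clip'
#   classify_tensor_key("first_stage_model.decoder.conv_in.weight")                                 # 'vae'
#   classify_tensor_key("model.diffusion_model.input_blocks.0.0.weight")                            # 'model'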
def find_safetensor_files(directory: str, recursive: bool = True) -> List[Path]:
    """Find all .safetensors files in a directory."""
    directory = Path(directory)
    if recursive:
        return list(directory.rglob("*.safetensors"))
    else:
        return list(directory.glob("*.safetensors"))
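# Illustrative (paths are hypothetical):
#   find_safetensor_files("/models", recursive=False)  # only /models/*.safetensors
#   find_safetensor_files("/models")                   # /models plus all subdirectories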
def extract_components(safetensor_path: Path, output_dir: Path, dry_run: bool = False):
    """Extract model, clip, and vae components from a safetensors file."""
    print(f"\nProcessing: {safetensor_path}")
    # one tensor dict per component
    components = {
        'model': {},
        'clip': {},
        'vae': {}
    }
    try:
        with MemoryEfficientSafeOpen(str(safetensor_path)) as f:
            keys = f.keys()
            print(f"Found {len(keys)} tensors")
            # classify every key; only load tensor data when actually extracting
            for key in keys:
                component_type = classify_tensor_key(key)
                if not dry_run:
                    tensor = f.get_tensor(key)
                    components[component_type][key] = tensor
                else:
                    components[component_type][key] = None
        # print classification summary
        print(f"  Model tensors: {len(components['model'])}")
        print(f"  CLIP tensors: {len(components['clip'])}")
        print(f"  VAE tensors: {len(components['vae'])}")
        if dry_run:
            return
        # write each non-empty component to its own subdirectory
        base_name = safetensor_path.stem
        for component_type, tensors in components.items():
            if not tensors:  # skip empty components
                continue
            component_dir = output_dir / component_type
            component_dir.mkdir(parents=True, exist_ok=True)
            output_file = component_dir / f"{base_name}_{component_type}.safetensors"
            mem_eff_save_file(tensors, str(output_file))
            print(f"  Saved {component_type}: {output_file}")
    except Exception as e:
        print(f"Error processing {safetensor_path}: {e}")
def main():
    parser = argparse.ArgumentParser(
        description="Extract model, clip, and vae components from safetensor files"
    )
    parser.add_argument(
        "input_dir",
        help="Directory to scan for safetensor files"
    )
    parser.add_argument(
        "-o", "--output",
        default="extracted_components",
        help="Output directory for extracted components (default: extracted_components)"
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="Recursively scan subdirectories"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be extracted without actually doing it"
    )
    parser.add_argument(
        "--pattern",
        help="Only process files matching this pattern (e.g., '*diffusion*')"
    )
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output)
    if not input_dir.exists():
        print(f"Error: Input directory {input_dir} does not exist")
        return

    # find safetensor files, optionally filtered by a glob pattern
    safetensor_files = find_safetensor_files(input_dir, args.recursive)
    if args.pattern:
        from fnmatch import fnmatch
        safetensor_files = [f for f in safetensor_files if fnmatch(f.name, args.pattern)]
    if not safetensor_files:
        print("No safetensor files found")
        return
    print(f"Found {len(safetensor_files)} safetensor files")

    if args.dry_run:
        print("\n=== DRY RUN MODE ===")
    else:
        output_dir.mkdir(parents=True, exist_ok=True)

    # process each file
    for safetensor_file in safetensor_files:
        extract_components(safetensor_file, output_dir, args.dry_run)

    if not args.dry_run:
        print(f"\nExtraction complete. Components saved to: {output_dir}")
    else:
        print(f"\nDry run complete. Would save components to: {output_dir}")


if __name__ == "__main__":
    main()
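# Example invocations (the script filename and paths are illustrative):
#   python extract_components.py /path/to/checkpoints
#   python extract_components.py /path/to/checkpoints -o extracted --recursive
#   python extract_components.py /path/to/checkpoints --dry-run --pattern "*diffusion*"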