Created
August 22, 2025 03:23
-
-
Save marduk191/561a0f7b2fb0f3b2430841a703a1190a to your computer and use it in GitHub Desktop.
Robust_Pytorch_to_safetensors
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Convert PyTorch files (.bin, .pth, .pt, .ckpt) to SafeTensors format. | |
| By:marduk191 | |
| """ | |
| import os | |
| import argparse | |
| import torch | |
| from safetensors.torch import save_file, save_model | |
| from pathlib import Path | |
| import logging | |
| from collections import defaultdict | |
# Root logging configuration used by every function in this script.
# INFO is the default level; main() switches the root logger to DEBUG
# when --verbose is passed.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# File suffixes recognised as PyTorch checkpoint/weight files.
PYTORCH_EXTENSIONS = {
    '.bin',
    '.pth',
    '.pt',
    '.ckpt',
}
def _storage_key(tensor):
    """
    Return an integer identifying the storage that backs ``tensor`` (0 if unknown).

    Slices and other views share their base tensor's storage but can start at a
    different ``data_ptr()``; keying on the storage pointer groups them together.
    """
    try:
        return tensor.untyped_storage().data_ptr()
    except AttributeError:
        # torch < 2.0 has no untyped_storage(); fall back to the legacy API below.
        pass
    except (NotImplementedError, RuntimeError):
        # Sparse / meta / other tensors without an accessible storage.
        return 0
    try:
        return tensor.storage().data_ptr()
    except (NotImplementedError, RuntimeError):
        return 0


def detect_shared_tensors(tensor_dict):
    """
    Detect tensors that share memory (the same underlying storage).

    Non-tensor values, empty tensors and tensors without an accessible storage
    are ignored, so they are never reported as shared.

    Args:
        tensor_dict (dict): Dictionary of tensors
    Returns:
        dict: Mapping from storage pointer to the list of tensor names that share
        that storage; only groups with more than one name are included.
    """
    shared_tensors = defaultdict(list)
    for name, tensor in tensor_dict.items():
        if not isinstance(tensor, torch.Tensor):
            continue
        # Empty tensors may all report a null pointer even though they hold no
        # data; grouping them would produce false "shared memory" reports.
        if tensor.numel() == 0:
            continue
        storage_ptr = _storage_key(tensor)
        if storage_ptr == 0:
            continue
        shared_tensors[storage_ptr].append(name)
    # Only return groups with more than one tensor
    return {ptr: names for ptr, names in shared_tensors.items() if len(names) > 1}
def _extract_tensor_container(loaded):
    """
    Locate the mapping of tensors inside an object returned by ``torch.load``.

    Accepts a bare state dict, a pickled ``torch.nn.Module``, or a checkpoint
    dict that nests the weights under ``'state_dict'``, ``'model'`` or
    ``'model_state_dict'`` (checked in that order). A candidate key whose value
    is neither a dict nor a module is skipped with a warning.

    Args:
        loaded: Object returned by ``torch.load``.
    Returns:
        dict: Mapping of names to values (values may still include non-tensors).
    Raises:
        ValueError: If ``loaded`` is neither a dict nor an ``nn.Module``.
    """
    # torch.save(model) pickles the whole module; its weights are what we want.
    if isinstance(loaded, torch.nn.Module):
        logger.info("Loaded object is an nn.Module; using its state_dict()")
        return loaded.state_dict()
    if not isinstance(loaded, dict):
        raise ValueError(f"Unexpected data type in PyTorch file: {type(loaded)}")
    for key in ('state_dict', 'model', 'model_state_dict'):
        if key not in loaded:
            continue
        candidate = loaded[key]
        # Some training frameworks store the module object itself under 'model'.
        if isinstance(candidate, torch.nn.Module):
            candidate = candidate.state_dict()
        if isinstance(candidate, dict):
            logger.info(f"Found '{key}' key in checkpoint")
            return candidate
        # e.g. 'model' holding a name/config string: keep looking instead of
        # crashing on .items() later.
        logger.warning(f"Ignoring checkpoint key '{key}' with non-dict type {type(candidate)}")
    # Assume the dict itself contains the tensors
    logger.info("Using entire dict as tensor data")
    return loaded


def _save_tensors(tensor_dict, output_path, shared_groups, use_save_model):
    """
    Serialize ``tensor_dict`` to ``output_path`` with the selected safetensors API.

    Args:
        tensor_dict (dict): Name -> tensor mapping to write (not modified).
        output_path (Path): Destination file.
        shared_groups (dict): Result of ``detect_shared_tensors(tensor_dict)``.
        use_save_model (bool): Use ``save_model`` (stores aliased tensors once)
            instead of ``save_file``.
    """
    if use_save_model:
        # save_model only calls .state_dict() on its argument, so a minimal holder
        # suffices (no attributes that a tensor named 'state_dict' could shadow).
        # It is used even when no sharing was detected here: safetensors runs its
        # own shared-storage check and, with nothing shared, simply delegates to
        # save_file. A copy is returned because save_model deletes the aliases it
        # drops from the dict it receives.
        class _StateDictHolder:
            def state_dict(self):
                return dict(tensor_dict)

        save_model(_StateDictHolder(), str(output_path))
        return

    # save_file rejects tensors that share storage, so every alias after the first
    # in each group gets its own copy -- this is the on-disk duplication the
    # --use-save-file option warns about.
    to_save = dict(tensor_dict)
    for names in shared_groups.values():
        for name in names[1:]:
            to_save[name] = to_save[name].clone()
    save_file(to_save, str(output_path))


def convert_pytorch_to_safetensors(input_path, output_path=None, force=False, use_save_model=True):
    """
    Convert a PyTorch file (.bin, .pth, .pt, .ckpt) to SafeTensors format.

    Only ``torch.Tensor`` values are written; other entries are skipped with a
    warning. Non-contiguous strided tensors are made contiguous first, because
    safetensors refuses to serialize them.

    Args:
        input_path (str | Path): Path to the input PyTorch file
        output_path (str | Path, optional): Path for the output .safetensors file;
            defaults to ``input_path`` with its suffix replaced. Missing parent
            directories are created.
        force (bool): Whether to overwrite existing output files
        use_save_model (bool): Use save_model (stores memory-sharing tensors once)
            instead of save_file (stores a separate copy per alias)
    Returns:
        str: Path to the created SafeTensors file
    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
        ValueError: If the extension is unsupported, the loaded object has an
            unexpected type, or no tensors are found.
        FileExistsError: If the output exists and ``force`` is False.
    """
    input_path = Path(input_path)
    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")
    if input_path.suffix.lower() not in PYTORCH_EXTENSIONS:
        # sorted() keeps the message stable; set iteration order is arbitrary.
        raise ValueError(
            f"Input file must have one of these extensions: {', '.join(sorted(PYTORCH_EXTENSIONS))}, got: {input_path.suffix}"
        )

    output_path = input_path.with_suffix('.safetensors') if output_path is None else Path(output_path)
    if output_path.exists() and not force:
        raise FileExistsError(f"Output file already exists: {output_path}. Use --force to overwrite.")

    try:
        logger.info(f"Loading PyTorch file: {input_path}")
        # map_location='cpu' avoids GPU memory issues / needing the original device.
        # SECURITY: torch.load unpickles the file, which can execute arbitrary code
        # unless weights_only=True is in effect (the default only from PyTorch 2.6).
        # Only convert files from trusted sources.
        loaded = torch.load(input_path, map_location='cpu')
        tensors = _extract_tensor_container(loaded)

        tensor_dict = {}
        for key, value in tensors.items():
            if not isinstance(value, torch.Tensor):
                logger.warning(f"Skipping non-tensor key '{key}' with type {type(value)}")
                continue
            if value.layout == torch.strided and not value.is_contiguous():
                # safetensors refuses non-contiguous tensors (e.g. transposed views).
                value = value.contiguous()
            tensor_dict[key] = value
        if not tensor_dict:
            raise ValueError("No tensors found in the input file")
        logger.info(f"Found {len(tensor_dict)} tensors to convert")

        shared_groups = detect_shared_tensors(tensor_dict)
        if shared_groups:
            logger.warning("Detected tensors that share memory:")
            for names in shared_groups.values():
                logger.warning(f" Shared memory group: {names}")
            if use_save_model:
                logger.info("Using save_model to handle shared tensors properly")
            else:
                logger.warning("Using save_file - this may cause memory duplication!")

        logger.info(f"Saving SafeTensors file: {output_path}")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        _save_tensors(tensor_dict, output_path, shared_groups, use_save_model)

        # Compare file sizes
        input_size = input_path.stat().st_size / (1024 * 1024)  # MB
        output_size = output_path.stat().st_size / (1024 * 1024)  # MB
        logger.info("Conversion completed successfully!")
        logger.info(f"Input size: {input_size:.2f} MB")
        logger.info(f"Output size: {output_size:.2f} MB")
        logger.info(f"Size difference: {output_size - input_size:.2f} MB")
        return str(output_path)
    except Exception as e:
        # Log so batch runs show the failing file's error, then let callers decide.
        logger.error(f"Error during conversion: {str(e)}")
        raise
def convert_directory(input_dir, output_dir=None, force=False, recursive=False, use_save_model=True):
    """
    Convert all PyTorch files (.bin, .pth, .pt, .ckpt) in a directory to SafeTensors format.

    Files are matched by extension case-insensitively. Each output keeps the
    source's path relative to ``input_dir``: in recursive mode the sub-directory
    layout is mirrored under ``output_dir``, and when ``output_dir`` is omitted
    every output lands next to its source file. If two inputs map to the same
    output (e.g. ``model.pt`` and ``model.bin``), only the first (in sorted
    order) is converted and the others are reported as failures.

    Args:
        input_dir (str | Path): Directory containing PyTorch files
        output_dir (str | Path, optional): Output directory (defaults to input_dir)
        force (bool): Whether to overwrite existing files
        recursive (bool): Whether to search subdirectories
        use_save_model (bool): Whether to use save_model for proper shared tensor handling
    Returns:
        int: Number of files converted successfully (0 if none were found).
    Raises:
        ValueError: If ``input_dir`` is not an existing directory.
    """
    input_dir = Path(input_dir)
    if not input_dir.is_dir():
        raise ValueError(f"Input directory does not exist: {input_dir}")
    output_dir = input_dir if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # One directory walk; lower() so 'MODEL.PT' is found just as
    # convert_pytorch_to_safetensors would accept it. is_file() drops
    # directories that happen to be named like '*.pt'.
    candidates = input_dir.rglob('*') if recursive else input_dir.glob('*')
    pytorch_files = sorted(
        path for path in candidates
        if path.suffix.lower() in PYTORCH_EXTENSIONS and path.is_file()
    )
    if not pytorch_files:
        logger.info(f"No PyTorch files found in {input_dir}")
        return 0
    logger.info(f"Found {len(pytorch_files)} PyTorch files to convert")

    converted_count = 0
    claimed_outputs = {}
    for pytorch_file in pytorch_files:
        # Relative path is just the file name in non-recursive mode.
        output_file = output_dir / pytorch_file.relative_to(input_dir).with_suffix('.safetensors')
        previous = claimed_outputs.get(output_file)
        if previous is not None:
            # Converting would overwrite (with force) or fail on the earlier output.
            logger.error(f"Failed to convert {pytorch_file}: output {output_file} is already produced from {previous}")
            continue
        claimed_outputs[output_file] = pytorch_file
        try:
            output_file.parent.mkdir(parents=True, exist_ok=True)
            convert_pytorch_to_safetensors(pytorch_file, output_file, force=force, use_save_model=use_save_model)
            converted_count += 1
        except Exception as e:
            # Best effort: report and move on to the next file.
            logger.error(f"Failed to convert {pytorch_file}: {str(e)}")
    logger.info(f"Successfully converted {converted_count}/{len(pytorch_files)} files")
    return converted_count
def main():
    """
    Command-line entry point.

    A directory given as ``input`` is processed in directory mode even without
    ``--directory``. ``--recursive`` on a single-file input has no effect and
    only triggers a warning.

    Returns:
        int: Exit status -- 0 on success, 1 if the conversion raised an exception
        (per-file failures inside directory mode are logged, not raised).
    """
    parser = argparse.ArgumentParser(
        description="Convert PyTorch files (.bin, .pth, .pt, .ckpt) to SafeTensors format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Convert a single file
python convert_to_safetensors.py model.pth
python convert_to_safetensors.py checkpoint.ckpt
python convert_to_safetensors.py weights.bin
# Convert with custom output name
python convert_to_safetensors.py model.pt -o model_safe.safetensors
# Convert with save_file method (may duplicate shared tensors)
python convert_to_safetensors.py model.pt --use-save-file
# Convert all PyTorch files in a directory
python convert_to_safetensors.py /path/to/models/ --directory
# Convert recursively with force overwrite
python convert_to_safetensors.py /path/to/models/ --directory --recursive --force
"""
    )
    parser.add_argument("input", help="Input PyTorch file (.bin, .pth, .pt, .ckpt) or directory")
    parser.add_argument("-o", "--output", help="Output file or directory path")
    parser.add_argument("-d", "--directory", action="store_true",
                        help="Process all PyTorch files in the input directory")
    parser.add_argument("-r", "--recursive", action="store_true",
                        help="Search subdirectories recursively (only with --directory)")
    parser.add_argument("-f", "--force", action="store_true",
                        help="Overwrite existing output files")
    parser.add_argument("--use-save-file", action="store_true",
                        help="Use save_file instead of save_model (may duplicate shared tensors)")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Enable verbose logging")
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Without this, a directory passed without --directory fails the file
    # extension check with a confusing "got: ''" message.
    directory_mode = args.directory or Path(args.input).is_dir()
    if args.recursive and not directory_mode:
        logger.warning("--recursive has no effect when converting a single file")

    # Determine whether to use save_model (default) or save_file
    use_save_model = not args.use_save_file
    try:
        if directory_mode:
            convert_directory(args.input, args.output, force=args.force,
                              recursive=args.recursive, use_save_model=use_save_model)
        else:
            convert_pytorch_to_safetensors(args.input, args.output, force=args.force,
                                           use_save_model=use_save_model)
    except Exception as e:
        logger.error(f"Conversion failed: {e}")
        return 1
    return 0
if __name__ == "__main__":
    # SystemExit gives sys.exit() semantics without relying on the site-provided
    # exit() builtin, which is missing under `python -S` and in frozen apps.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment