Skip to content

Instantly share code, notes, and snippets.

@marduk191
Created August 22, 2025 03:23
Show Gist options
  • Select an option

  • Save marduk191/561a0f7b2fb0f3b2430841a703a1190a to your computer and use it in GitHub Desktop.

Select an option

Save marduk191/561a0f7b2fb0f3b2430841a703a1190a to your computer and use it in GitHub Desktop.
Robust_Pytorch_to_safetensors
#!/usr/bin/env python3
"""
Convert PyTorch files (.bin, .pth, .pt, .ckpt) to SafeTensors format.
By:marduk191
"""
import os
import argparse
import torch
from safetensors.torch import save_file, save_model
from pathlib import Path
import logging
from collections import defaultdict
# Configure the root logger once at import time: INFO and above, with timestamp and level
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# File suffixes (lower-case, leading dot included) treated as PyTorch checkpoints
PYTORCH_EXTENSIONS = {'.pt', '.pth', '.bin', '.ckpt'}
def detect_shared_tensors(tensor_dict):
    """
    Detect tensors that share memory (same data_ptr).

    Values that are not ``torch.Tensor`` instances are ignored. Zero-element
    tensors are ignored as well: they hold no data, so their ``data_ptr()`` is
    not a meaningful identity (it is typically the same placeholder value for
    every empty tensor) and would otherwise group unrelated empty tensors into
    a bogus "shared" set.

    Args:
        tensor_dict (dict): Dictionary mapping names to tensors
    Returns:
        dict: Mapping from data_ptr to list of tensor names that share that memory.
            Only groups containing more than one name are included.
    """
    shared_tensors = defaultdict(list)
    for name, tensor in tensor_dict.items():
        if not isinstance(tensor, torch.Tensor):
            continue
        if tensor.numel() == 0:
            # Nothing is stored, so there is no memory that could be shared
            continue
        shared_tensors[tensor.data_ptr()].append(name)
    # Only return groups with more than one tensor
    return {ptr: names for ptr, names in shared_tensors.items() if len(names) > 1}
def convert_pytorch_to_safetensors(input_path, output_path=None, force=False, use_save_model=True):
    """
    Convert a PyTorch file (.bin, .pth, .pt, .ckpt) to SafeTensors format.

    The checkpoint is loaded on CPU. If it is a dict wrapping the weights under
    'state_dict', 'model' or 'model_state_dict', that entry is used; otherwise
    the dict itself is used. Non-tensor entries are skipped with a warning.

    Args:
        input_path (str): Path to the input PyTorch file
        output_path (str, optional): Path for the output .safetensors file
            (defaults to the input path with a .safetensors suffix)
        force (bool): Whether to overwrite existing output files
        use_save_model (bool): Whether to use save_model for proper shared tensor handling.
            When False, tensors detected as sharing memory are cloned so that
            save_file can write them (each copy is stored separately on disk).
    Returns:
        str: Path to the created SafeTensors file
    Raises:
        FileNotFoundError: If input_path does not exist
        FileExistsError: If the output exists and force is False
        ValueError: If the suffix is unsupported, the loaded object is not a
            dict of weights, or it contains no tensors
    """
    input_path = Path(input_path)
    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")
    if input_path.suffix.lower() not in PYTORCH_EXTENSIONS:
        # sorted() keeps the message stable; set iteration order is not
        raise ValueError(
            f"Input file must have one of these extensions: "
            f"{', '.join(sorted(PYTORCH_EXTENSIONS))}, got: {input_path.suffix}"
        )
    # Generate output path if not provided
    if output_path is None:
        output_path = input_path.with_suffix('.safetensors')
    else:
        output_path = Path(output_path)
    # Check if output file already exists
    if output_path.exists() and not force:
        raise FileExistsError(f"Output file already exists: {output_path}. Use --force to overwrite.")
    try:
        logger.info(f"Loading PyTorch file: {input_path}")
        # Use map_location='cpu' to avoid GPU memory issues.
        # NOTE(security): torch.load unpickles the file. Depending on the installed
        # torch version's `weights_only` default, loading an untrusted checkpoint
        # may execute arbitrary code -- only convert files from sources you trust.
        checkpoint = torch.load(input_path, map_location='cpu')
        tensors = _unwrap_checkpoint(checkpoint)

        # Keep tensors only, warning once about everything that is dropped
        tensor_dict = {}
        for key, value in tensors.items():
            if isinstance(value, torch.Tensor):
                tensor_dict[key] = value
            else:
                logger.warning(f"Skipping non-tensor key '{key}' with type {type(value)}")
        if not tensor_dict:
            raise ValueError("No tensors found in the input file")
        logger.info(f"Found {len(tensor_dict)} tensors to convert")

        # Check for shared tensors
        shared_groups = detect_shared_tensors(tensor_dict)
        if shared_groups:
            logger.warning("Detected tensors that share memory:")
            for names in shared_groups.values():
                logger.warning(f" Shared memory group: {names}")
            if use_save_model:
                logger.info("Using save_model to handle shared tensors properly")
            else:
                logger.warning("Using save_file - this may cause memory duplication!")

        # Save as SafeTensors
        logger.info(f"Saving SafeTensors file: {output_path}")
        if use_save_model and shared_groups:
            # save_model is given a tiny holder instead of a real nn.Module; it
            # receives the very dict built above through state_dict()
            save_model(_StateDictHolder(tensor_dict), output_path)
        else:
            # save_file raises on tensors that alias the same memory, so give
            # every alias after the first its own copy (this is the duplication
            # the warning above refers to)
            for names in shared_groups.values():
                for name in names[1:]:
                    tensor_dict[name] = tensor_dict[name].clone()
            # save_file also rejects non-contiguous tensors; contiguous() is a
            # no-op for tensors that already are
            tensor_dict = {name: tensor.contiguous() for name, tensor in tensor_dict.items()}
            save_file(tensor_dict, output_path)

        # Compare file sizes
        bytes_per_mb = 1024 * 1024
        input_size = input_path.stat().st_size / bytes_per_mb
        output_size = output_path.stat().st_size / bytes_per_mb
        logger.info("Conversion completed successfully!")
        logger.info(f"Input size: {input_size:.2f} MB")
        logger.info(f"Output size: {output_size:.2f} MB")
        logger.info(f"Size difference: {output_size - input_size:.2f} MB")
        return str(output_path)
    except Exception as e:
        # Log here for context, then let the caller decide how to handle it
        logger.error(f"Error during conversion: {str(e)}")
        raise


def _unwrap_checkpoint(checkpoint):
    """Return the name -> value mapping that holds the weights of a loaded checkpoint."""
    if not isinstance(checkpoint, dict):
        raise ValueError(f"Unexpected data type in PyTorch file: {type(checkpoint)}")
    # Same precedence as before: the first wrapper key present wins
    for wrapper_key in ('state_dict', 'model', 'model_state_dict'):
        if wrapper_key not in checkpoint:
            continue
        logger.info(f"Found '{wrapper_key}' key in checkpoint")
        nested = checkpoint[wrapper_key]
        if isinstance(nested, torch.nn.Module):
            # Some checkpoints pickle the whole module rather than its weights
            nested = nested.state_dict()
        if not isinstance(nested, dict):
            raise ValueError(f"Unexpected data type under '{wrapper_key}' in PyTorch file: {type(nested)}")
        return nested
    # Assume the dict itself contains the tensors
    logger.info("Using entire dict as tensor data")
    return checkpoint


class _StateDictHolder:
    """Minimal stand-in for a model that exposes only ``state_dict()``."""

    def __init__(self, state_dict):
        self._state_dict = state_dict

    def state_dict(self):
        # Return the same dict object (not a copy) that was passed in
        return self._state_dict
def convert_directory(input_dir, output_dir=None, force=False, recursive=False, use_save_model=True):
    """
    Convert all PyTorch files (.bin, .pth, .pt, .ckpt) in a directory to SafeTensors format.

    Each output keeps its source's path relative to input_dir, so in recursive
    mode the sub-directory layout is mirrored under output_dir (and, when
    output_dir is omitted, each result is written next to its source file).
    Failures on individual files are logged and do not stop the run.

    Args:
        input_dir (str): Directory containing PyTorch files
        output_dir (str, optional): Output directory (defaults to input_dir)
        force (bool): Whether to overwrite existing files
        recursive (bool): Whether to search subdirectories
        use_save_model (bool): Whether to use save_model for proper shared tensor handling
    Raises:
        ValueError: If input_dir does not exist or is not a directory
    """
    input_dir = Path(input_dir)
    if not input_dir.is_dir():  # also False when the path does not exist
        raise ValueError(f"Input directory does not exist: {input_dir}")
    if output_dir is None:
        output_dir = input_dir
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Find all PyTorch files. Matching on the lower-cased suffix agrees with the
    # single-file check; is_file() skips directories that merely look like
    # checkpoints; sorting makes the processing order deterministic.
    candidates = input_dir.rglob('*') if recursive else input_dir.glob('*')
    pytorch_files = sorted(
        path for path in candidates
        if path.is_file() and path.suffix.lower() in PYTORCH_EXTENSIONS
    )
    if not pytorch_files:
        logger.info(f"No PyTorch files found in {input_dir}")
        return
    logger.info(f"Found {len(pytorch_files)} PyTorch files to convert")

    converted_count = 0
    written = {}  # output path -> source that produced it during this run
    for pytorch_file in pytorch_files:
        # Mirror the source's relative location; for a non-recursive scan this
        # is simply the file name inside output_dir
        output_file = output_dir / pytorch_file.relative_to(input_dir).with_suffix('.safetensors')
        if output_file in written:
            # e.g. model.pt and model.bin both map to model.safetensors; never
            # let --force silently replace a file produced moments ago
            logger.error(
                f"Failed to convert {pytorch_file}: output {output_file} "
                f"was already written from {written[output_file]}"
            )
            continue
        try:
            output_file.parent.mkdir(parents=True, exist_ok=True)
            convert_pytorch_to_safetensors(pytorch_file, output_file, force=force, use_save_model=use_save_model)
        except Exception as e:
            # Best effort: report and move on to the next file
            logger.error(f"Failed to convert {pytorch_file}: {str(e)}")
            continue
        written[output_file] = pytorch_file
        converted_count += 1
    logger.info(f"Successfully converted {converted_count}/{len(pytorch_files)} files")
def main(argv=None):
    """
    Command-line entry point.

    Args:
        argv (list[str], optional): Arguments to parse instead of sys.argv[1:].
            Lets the tool be driven programmatically; None keeps the normal
            command-line behavior.
    Returns:
        int: 0 on success, 1 if the conversion raised an exception
    """
    parser = argparse.ArgumentParser(
        description="Convert PyTorch files (.bin, .pth, .pt, .ckpt) to SafeTensors format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Convert a single file
python convert_to_safetensors.py model.pth
python convert_to_safetensors.py checkpoint.ckpt
python convert_to_safetensors.py weights.bin
# Convert with custom output name
python convert_to_safetensors.py model.pt -o model_safe.safetensors
# Convert with save_file method (may duplicate shared tensors)
python convert_to_safetensors.py model.pt --use-save-file
# Convert all PyTorch files in a directory
python convert_to_safetensors.py /path/to/models/ --directory
# Convert recursively with force overwrite
python convert_to_safetensors.py /path/to/models/ --directory --recursive --force
"""
    )
    parser.add_argument("input", help="Input PyTorch file (.bin, .pth, .pt, .ckpt) or directory")
    parser.add_argument("-o", "--output", help="Output file or directory path")
    parser.add_argument("-d", "--directory", action="store_true",
                        help="Process all PyTorch files in the input directory")
    parser.add_argument("-r", "--recursive", action="store_true",
                        help="Search subdirectories recursively (only with --directory)")
    parser.add_argument("-f", "--force", action="store_true",
                        help="Overwrite existing output files")
    parser.add_argument("--use-save-file", action="store_true",
                        help="Use save_file instead of save_model (may duplicate shared tensors)")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Enable verbose logging")
    args = parser.parse_args(argv)

    if args.verbose:
        # Raise verbosity on the root logger so every module's records show
        logging.getLogger().setLevel(logging.DEBUG)
    if args.recursive and not args.directory:
        # Previously ignored without a word; tell the user the flag did nothing
        logger.warning("--recursive has no effect without --directory")

    # Determine whether to use save_model (default) or save_file
    use_save_model = not args.use_save_file
    try:
        if args.directory:
            convert_directory(args.input, args.output, args.force, args.recursive, use_save_model)
        else:
            convert_pytorch_to_safetensors(args.input, args.output, args.force, use_save_model)
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the exit code
        logger.error(f"Conversion failed: {str(e)}")
        return 1
    return 0
if __name__ == "__main__":
    # The bare exit() helper is injected by the `site` module for interactive
    # use and is absent under `python -S` or in frozen builds; raising
    # SystemExit sets the same process exit status without relying on it.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment