Last active
July 21, 2025 00:51
-
-
Save marduk191/3a924758c4f903b26ad1e28f79090f8e to your computer and use it in GitHub Desktop.
PyTorch (.pt/.pth) to SafeTensors converter with minimal checking; writes fp32 output.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Minimal SafeTensors Converter | |
Converts PyTorch model files (.pt, .pth) to .safetensors format | |
Usage: python converter.py <input_path> [output_path] | |
by:marduk191 | |
https://github.com/marduk191 | |
""" | |
import os | |
import sys | |
import torch | |
import warnings | |
from safetensors.torch import save_file | |
from collections import OrderedDict | |
from typing import Any, Dict, Union | |
# Silence every warning category for the whole process (e.g. torch.load
# deprecation/security notices). NOTE(review): this also hides dtype-cast
# warnings; consider narrowing the filter if diagnostics matter.
warnings.simplefilter("ignore")
def flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = "_") -> "OrderedDict[str, Any]":
    """Flatten nested dicts into a single-level OrderedDict.

    Nested keys are joined with *sep*, so ``{"a": {"b": x}}`` becomes
    ``{"a_b": x}``. Every key is converted to ``str`` because safetensors
    only accepts string keys (nested dicts such as optimizer state may use
    integer keys).

    Args:
        d: Mapping to flatten; any value that is a dict (including
            OrderedDict) is recursed into.
        parent_key: Prefix prepended to every key; used by the recursion.
        sep: Separator inserted between a parent key and a child key.

    Returns:
        OrderedDict mapping flattened keys to leaf values, in traversal order.

    Raises:
        ValueError: if two different paths flatten to the same key
            (e.g. ``{"a": {"b": 1}, "a_b": 2}``), instead of silently
            dropping one of the values.
    """
    flat: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):  # OrderedDict is a dict subclass
            children = flatten_dict(value, new_key, sep=sep)
        else:
            children = {new_key: value}
        for child_key, child_value in children.items():
            if child_key in flat:
                raise ValueError(f"Key collision while flattening: {child_key!r}")
            flat[child_key] = child_value
    return flat
def convert_to_float32(tensor: Union[torch.Tensor, list, tuple, dict]) -> Union[torch.Tensor, list, tuple, dict]:
    """Recursively cast floating-point tensors to contiguous float32.

    Only floating-point tensors (fp16, bf16, fp64, ...) are cast. Integer,
    bool and complex tensors keep their dtype (but are still made
    contiguous): ``.float()`` on them would lose int64 precision above
    2**24, turn masks into numbers, or discard imaginary parts.

    Args:
        tensor: A tensor, or a list/tuple/namedtuple/dict containing tensors
            at any depth. Any other value is returned unchanged.

    Returns:
        The same structure with tensors converted. Lists and tuples keep
        their type; dicts come back as plain ``dict``.
    """
    if isinstance(tensor, torch.Tensor):
        if tensor.is_floating_point():
            tensor = tensor.float()
        return tensor.contiguous()
    if isinstance(tensor, tuple) and hasattr(tensor, "_fields"):
        # namedtuple constructors take positional fields, not one iterable.
        return type(tensor)(*(convert_to_float32(t) for t in tensor))
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(convert_to_float32(t) for t in tensor)
    if isinstance(tensor, dict):
        return {k: convert_to_float32(v) for k, v in tensor.items()}
    return tensor
def get_state_dict(checkpoint: Union[torch.nn.Module, Dict[str, Any]]) -> Dict[str, Any]:
    """Pull the weight mapping out of a loaded checkpoint object.

    - ``nn.Module``: its ``state_dict()`` is returned.
    - ``dict``: the value under ``"state_dict"`` when that key exists,
      otherwise the dict itself.
    - anything else: ``ValueError``.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise ValueError("Unsupported checkpoint format")
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint
def _load_checkpoint(input_file: str) -> Any:
    """Load a checkpoint onto the CPU, preferring torch's restricted unpickler."""
    try:
        return torch.load(input_file, map_location="cpu", weights_only=True)
    except OSError:
        raise  # missing/unreadable file: the unsafe retry would fail the same way
    except Exception as safe_err:
        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code embedded in the file. Only convert files you trust. Announced on
        # stderr because the module-level warnings filter hides warnings.
        print(
            f"Warning: safe load of {input_file} failed ({safe_err}); "
            "retrying with weights_only=False (unsafe for untrusted files)",
            file=sys.stderr,
        )
        return torch.load(input_file, map_location="cpu", weights_only=False)


def _storage_ptr(tensor: torch.Tensor) -> int:
    """Return the data pointer of the storage backing *tensor*."""
    try:
        return tensor.untyped_storage().data_ptr()
    except AttributeError:  # torch < 2.0 has no untyped_storage()
        return tensor.storage().data_ptr()


def _prepare_for_safetensors(flat: Dict[str, Any], source: str) -> Dict[str, torch.Tensor]:
    """Keep only tensor entries and give each one its own storage.

    save_file accepts only tensor values and refuses tensors that share a
    storage (tied weights, slices of one fused parameter). The first tensor
    seen for a storage is kept as-is; later ones are cloned.
    """
    prepared: Dict[str, torch.Tensor] = {}
    seen_storages = set()
    skipped = []
    for key, value in flat.items():
        if not isinstance(value, torch.Tensor):
            skipped.append(str(key))
            continue
        value = value.contiguous()
        ptr = _storage_ptr(value)
        if ptr in seen_storages:
            value = value.clone()
        else:
            seen_storages.add(ptr)
        prepared[key] = value
    if skipped:
        preview = ", ".join(skipped[:5]) + (", ..." if len(skipped) > 5 else "")
        print(f"Note: skipped {len(skipped)} non-tensor entries in {source}: {preview}")
    if not prepared:
        raise ValueError("no tensors found in checkpoint")
    return prepared


def convert_file(input_file: str, output_file: str) -> bool:
    """Convert a single PyTorch checkpoint into a .safetensors file.

    Pipeline:
    1. Load the file on the CPU.
    2. Extract the state dict with get_state_dict.
    3. Flatten nested dicts with flatten_dict.
    4. Convert the tensors with convert_to_float32.
    5. Drop non-tensor entries and un-share shared storages.
    6. Write the result with safetensors' save_file.

    Args:
        input_file: Path of the .pt/.pth checkpoint to read.
        output_file: Path of the .safetensors file to write.

    Returns:
        True on success; False if any step raised (the error is printed).
    """
    try:
        checkpoint = _load_checkpoint(input_file)
        state_dict = get_state_dict(checkpoint)
        processed = convert_to_float32(flatten_dict(state_dict))
        save_file(_prepare_for_safetensors(processed, input_file), output_file)
        return True
    except Exception as e:
        # Per-file boundary: report and return False so batch mode continues.
        print(f"Failed to convert {input_file}: {e}")
        return False
def _default_output_name(path: str) -> str:
    """Return *path* with its final extension replaced by ``.safetensors``.

    os.path.splitext touches only the real extension, so "model.pth" becomes
    "model.safetensors" and ".pt" inside directory names is left alone.
    """
    return os.path.splitext(path)[0] + ".safetensors"


def main():
    """Command-line entry point: convert one file or every .pt/.pth in a folder.

    Usage: converter.py <input_path> [output_path]
      * file input: output_path defaults to the input path with a
        .safetensors extension.
      * directory input: output_path is a directory and defaults to
        <input_path>/converted.
    Exits with status 1 on bad arguments or on any failed conversion.
    """
    if len(sys.argv) < 2:
        print("Usage: python converter.py <input_path> [output_path]")
        sys.exit(1)

    input_path = os.path.abspath(sys.argv[1])
    output_arg = sys.argv[2] if len(sys.argv) >= 3 else None

    if os.path.isfile(input_path):
        # Single file
        output_path = os.path.abspath(output_arg or _default_output_name(input_path))
        if output_path == input_path:
            print(f"Error: output path {output_path} would overwrite the input file")
            sys.exit(1)
        # abspath guarantees a non-empty dirname; os.makedirs("") would raise.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        success = convert_file(input_path, output_path)
        print(f"Conversion {'successful' if success else 'failed'}")
        if not success:
            sys.exit(1)
    elif os.path.isdir(input_path):
        # Directory
        output_dir = output_arg or os.path.join(input_path, "converted")
        os.makedirs(output_dir, exist_ok=True)
        files = sorted(
            f
            for f in os.listdir(input_path)
            if f.endswith((".pt", ".pth")) and os.path.isfile(os.path.join(input_path, f))
        )
        successful = 0
        used_outputs = set()
        for file in files:
            out_name = _default_output_name(file)
            if out_name in used_outputs:
                # e.g. "a.pt" and "a.pth" both map to "a.safetensors"
                print(f"✗ {file} (would overwrite {out_name} from another input)")
                continue
            used_outputs.add(out_name)
            input_file = os.path.join(input_path, file)
            output_file = os.path.join(output_dir, out_name)
            if convert_file(input_file, output_file):
                successful += 1
                print(f"✓ {file}")
            else:
                print(f"✗ {file}")
        print(f"Converted {successful}/{len(files)} files")
        if successful < len(files):
            sys.exit(1)
    else:
        print(f"Error: {input_path} is not a valid file or directory")
        sys.exit(1)


if __name__ == "__main__":
    main()
Usage: python minimal_safetensors_converter.py <input_path> [output_path]
Examples:
-
explicit paths: python minimal_safetensors_converter.py input.pt output.safetensors
-
ex2: python minimal_safetensors_converter.py c:\myinputfile.pt d:\outputs\myoutputfile.safetensors
-
Output next to the input (same name, .safetensors extension; the input file is not modified): python minimal_safetensors_converter.py myinput.pt
-
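Batch (converts every .pt/.pth file in the folder; if no output folder is given, results go to input_folder\converted): python minimal_safetensors_converter.py c:\input_folder f:\folder\folder\output_folder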
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example:
python ./minimal_safetensors_converter.py H:\convert\in\filename.pth C:\this\that\filename.safetensors