
@maximousblk
Last active December 12, 2022 23:36
A simple script to convert torch models to safetensors format

SD Pickle to Safetensors Converter

A simple script to convert Stable Diffusion checkpoints from PyTorch's pickle-based `.ckpt` format to the safetensors format.

Usage

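```sh
# Clone the script and install dependencies
git clone https://gist.github.com/210b55ea7b082bf2a47e2f200d1efbb1.git pickle2safetensors
cd pickle2safetensors
pip install torch safetensors

# Run the script
python pickle2safetensors.py --input path/to/model.ckpt
```

The script requires Python 3.9 or newer because it uses `argparse.BooleanOptionalAction`. Optional flags:

- `--output`: output path without extension (default: `model`)
- `--fp16` / `--no-fp16`: save the weights in half precision and add `.fp16` to the file name (default: off)
- `--device`: device passed to `torch.load` as `map_location` (default: `cpu`)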
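The script builds the output file name from `--output` (default `model`, relative to the current directory) plus an extension that depends on `--fp16`. The helper below is a hypothetical restatement of that rule. `output_path` is not part of the gist. You can use it to predict where a run will write.

```python
def output_path(output: str = "model", fp16: bool = False) -> str:
    """Mirror the script's naming rule: <output>[.fp16].safetensors."""
    # The script inserts ".fp16" only when --fp16 is passed
    extension = f"{'.fp16' if fp16 else ''}.safetensors"
    return output + extension


print(output_path())                      # model.safetensors
print(output_path("sd-v1-5", fp16=True))  # sd-v1-5.fp16.safetensors
```

The extension is always appended, so passing `--output out.safetensors` produces `out.safetensors.safetensors`. If the target file already exists, the script asks before overwriting it.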
The gist's `.gitignore` keeps model files out of version control:

```gitignore
*.safetensors
*.ckpt
```
`pickle2safetensors.py`:

```python
import os
import argparse

import torch
from safetensors.torch import save_file

# This file is a command-line tool, not a module
if __name__ != "__main__":
    raise Exception("This script is not meant to be imported")

parser = argparse.ArgumentParser(description="Convert a model from pickle to safetensors format")
parser.add_argument("--input", type=str, help="Path to input model in torch format (.ckpt)", required=True)
parser.add_argument("--output", type=str, help="Path to output model (without extension)", default="model", required=False)
parser.add_argument("--fp16", action=argparse.BooleanOptionalAction, help="Whether to use half precision", default=False, required=False)
parser.add_argument("--device", type=str, help="Device to use (defaults to 'cpu')", default="cpu", required=False)
args = parser.parse_args()

print(f"• Loading model from {args.input}...")
# torch.load unpickles the file: only load checkpoints you trust
checkpoint = torch.load(args.input, map_location=args.device)
# SD checkpoints nest the weights under "state_dict"; fall back to the
# top-level dict for files that store a bare state dict
weights = checkpoint.get("state_dict", checkpoint)
# safetensors stores only tensors, and save_file rejects non-contiguous ones
weights = {k: v.contiguous() for k, v in weights.items() if isinstance(v, torch.Tensor)}

if args.fp16:
    print("• Converting to half precision...")
    # Cast only floating-point tensors; integer buffers keep their dtype
    weights = {k: v.half() if v.is_floating_point() else v for k, v in weights.items()}

output_extension = f"{'.fp16' if args.fp16 else ''}.safetensors"
output_file = args.output + output_extension

# Ask before overwriting; loop until the user confirms or picks a free name
while os.path.isfile(output_file):
    overwrite = input(f"! Output file '{output_file}' already exists. Overwrite? [y/N]: ")
    if overwrite.lower() == "y":
        break
    filename = input("? Please enter a new output file name (without extension): ")
    if filename:
        output_file = filename + output_extension

print(f"• Saving to {output_file}...")
save_file(weights, output_file)
print("✓ Done!")
```
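A safetensors file has three parts:

- an 8-byte little-endian unsigned integer giving the length of a JSON header,
- the header itself, which maps each tensor name to its dtype, shape and byte offsets and may include an optional `__metadata__` entry,
- the raw tensor bytes.

The sketch below reads only the header, using nothing but the standard library. This is a cheap way to check a converted file's tensor names and dtypes without loading torch. `read_safetensors_header` and `summarize` are hypothetical helpers. They are not part of the gist or the safetensors package.

```python
import json
import struct


def read_safetensors_header(path: str) -> dict:
    """Return the JSON header of a .safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: header length as an unsigned 64-bit little-endian integer
        (header_size,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_size).decode("utf-8"))
    # "__metadata__" holds optional string metadata, not a tensor entry
    header.pop("__metadata__", None)
    return header


def summarize(path: str) -> dict:
    """Count tensors per dtype string such as "F16", "F32" or "I64"."""
    counts: dict = {}
    for info in read_safetensors_header(path).values():
        counts[info["dtype"]] = counts.get(info["dtype"], 0) + 1
    return counts
```

For example, `summarize("model.fp16.safetensors")` after an `--fp16` run should report `F16` for the floating-point weights.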