
@zer0TF
Created November 28, 2022 04:53
Convert all CKPT files to SAFETENSOR files in a directory
# Got a bunch of .ckpt files to convert?
# Here's a handy script to take care of all that for you!
# Original .ckpt files are not touched!
# Make sure you have enough disk space! You are going to DOUBLE the size of your models folder!
#
# First, run:
# pip install torch torchsde==0.2.5 safetensors==0.2.5
#
# Place this file in the **SAME DIRECTORY** as all of your .ckpt files, open a command prompt for that folder, and run:
# python convert_to_safe.py

import os
import torch
from safetensors.torch import save_file

files = os.listdir()

for f in files:
    if f.lower().endswith('.ckpt'):
        print(f'Loading {f}...')

        # Same base name with a .safetensors extension (works for .ckpt and .CKPT)
        fn = f"{os.path.splitext(f)[0]}.safetensors"

        if fn in files:
            print(f'Skipping, as {fn} already exists.')
            continue

        try:
            with torch.no_grad():
                # Keep only the model weights stored under "state_dict"
                weights = torch.load(f)["state_dict"]
                print(f'Saving {fn}...')
                save_file(weights, fn)
        except Exception as ex:
            print(f'ERROR converting {f}: {ex}')

print('Done!')
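The script's selection logic (pick `.ckpt` files, derive the output name, skip anything already converted) can be pulled out into a pure function, which makes it easy to preview what would be converted before loading any multi-gigabyte checkpoint. This is a minimal sketch; `plan_conversions` is a hypothetical helper name, not part of the gist, and it roughly mirrors the loop above.

```python
import os


def plan_conversions(files):
    """Return (ckpt_name, safetensors_name) pairs that still need converting.

    Picks names whose extension is .ckpt in any letter case and skips any
    whose .safetensors counterpart already appears in the same listing.
    """
    existing = set(files)
    plan = []
    for name in files:
        stem, ext = os.path.splitext(name)
        if ext.lower() != ".ckpt":
            continue  # not a checkpoint
        target = stem + ".safetensors"
        if target in existing:
            continue  # already converted on a previous run
        plan.append((name, target))
    return plan
```

For a dry run in the models folder: `for src, dst in plan_conversions(os.listdir()): print(f"{src} -> {dst}")`.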
@cooperdk

if 'state_dict' in weights:
    weights = weights.pop('state_dict')

Maybe? (So it doesn't rely on the filename.)

I agree, that would be better. It also supports any model with that dictionary.

@imacopypaster

if 'state_dict' in weights:
    weights = weights.pop('state_dict')

Please tell me where to add these lines to the code.

@Narsil

Narsil commented Jan 16, 2023

After line 28
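The thread's suggestion can be written as a small helper: instead of assuming every checkpoint wraps its weights in a `"state_dict"` key, use that key when it is present and otherwise treat the loaded object as the weights themselves. This is a sketch under the assumption that you change the script's `torch.load(f)["state_dict"]` to `unwrap_state_dict(torch.load(f))`; `unwrap_state_dict` is a hypothetical name.

```python
def unwrap_state_dict(checkpoint):
    """Return the weight dictionary from an object loaded from a .ckpt file.

    Some checkpoints nest the weights under "state_dict" next to other
    entries (for example training metadata); others store the weights at
    the top level. Handle both instead of hard-coding one layout.
    """
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint
```

This only decides which dictionary gets saved. `save_file` expects every value to be a tensor, so a file that keeps non-tensor entries at its top level may still need those entries filtered out first.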

@AskerCPU87456

Hey, I'm getting this error:

"ERROR converting xxxx.ckpt: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU."

How do I fix this? I'm running it on a VM with no GPU.

Thank you.
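The error message already names the fix: tell `torch.load` to map storages onto the CPU. Below is a minimal sketch of that change as a function, assuming the same `["state_dict"]` layout the script expects; `load_weights_on_cpu` is a hypothetical name, not part of the gist.

```python
def load_weights_on_cpu(ckpt_path):
    """Load a checkpoint's "state_dict" into CPU memory, even if it was saved on a GPU."""
    # Imported here so the sketch stands on its own; in the script, the
    # top-level `import torch` already covers this.
    import torch

    # map_location="cpu" remaps every CUDA storage to CPU memory, which is
    # what the error message recommends for CPU-only machines.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    return checkpoint["state_dict"]
```

Inside the script itself, the equivalent one-line change would be `weights = torch.load(f, map_location="cpu")["state_dict"]`. Writing a .safetensors file does not need a GPU, so mapping to the CPU should also work on machines that do have one.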
