Created November 28, 2022 04:53
Convert all CKPT files to SAFETENSOR files in a directory
# Got a bunch of .ckpt files to convert?
# Here's a handy script to take care of all that for you!
# Original .ckpt files are not touched!
# Make sure you have enough disk space! You are going to DOUBLE the size of your models folder!
#
# First, run:
# pip install torch torchsde==0.2.5 safetensors==0.2.5
#
# Place this file in the **SAME DIRECTORY** as all of your .ckpt files, open a command prompt for that folder, and run:
# python convert_to_safe.py

import os
import torch
from safetensors.torch import save_file

files = os.listdir()

for f in files:
    if f.lower().endswith('.ckpt'):
        print(f'Loading {f}...')

        fn = f"{f.replace('.ckpt', '')}.safetensors"

        if fn in files:
            print(f'Skipping, as {fn} already exists.')
            continue

        try:
            with torch.no_grad():
                weights = torch.load(f)["state_dict"]
                print(f'Saving {fn}...')
                save_file(weights, fn)
        except Exception as ex:
            print(f'ERROR converting {f}: {ex}')

print('Done!')
Is there any way to do this the other way around (convert .safetensors back to .ckpt)?
This is just a guess, but since the difference between the formats is the pickled Python code (the part that makes .ckpt unsafe), have you tried simply renaming the model to .ckpt?
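In case a plain rename isn't enough (the two formats are laid out differently on disk), here is a rough sketch of the reverse conversion using safetensors' load_file. The filenames are placeholders, and wrapping the tensors in a {'state_dict': ...} dict is an assumption about what the consuming loader expects:

import torch
from safetensors.torch import load_file

# Hypothetical reverse conversion: .safetensors -> .ckpt
weights = load_file('model.safetensors')            # flat dict of tensors
torch.save({'state_dict': weights}, 'model.ckpt')   # wrap it the way most .ckpt loaders expect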
if 'state_dict' in weights: weights.pop('state_dict')
Maybe? (So it doesn't rely on the filename.)
I agree, that would be better. It also supports any model with that dictionary.
if 'state_dict' in weights: weights.pop('state_dict')
Please tell me where to add these lines to the code.
After line 28, i.e. right after the weights are loaded (weights = torch.load(f)["state_dict"]).
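For anyone looking for the exact spot, here is a sketch of the patched try block with the suggested check placed right after the weights are loaded. The comment reflects an assumption that some checkpoints carry a nested, non-tensor 'state_dict' entry:

try:
    with torch.no_grad():
        weights = torch.load(f)["state_dict"]
        # Drop a nested 'state_dict' entry if one is present;
        # save_file() can only serialize tensors.
        if 'state_dict' in weights:
            weights.pop('state_dict')
        print(f'Saving {fn}...')
        save_file(weights, fn)
except Exception as ex:
    print(f'ERROR converting {f}: {ex}')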
Hey, I'm getting this error
"ERROR converting xxxx.ckpt: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU."
How do I fix this? I'm running it on a VM with no GPU.
Thank you.
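The error message itself points at the fix: pass map_location to torch.load so CUDA-saved tensors are mapped onto the CPU. In this script, that would mean changing the load line to something like:

weights = torch.load(f, map_location=torch.device('cpu'))["state_dict"]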