Skip to content

Instantly share code, notes, and snippets.

@RassilonSleeps
Last active December 4, 2022 08:25
Show Gist options
  • Save RassilonSleeps/dbd98723398570f701087f6340080930 to your computer and use it in GitHub Desktop.
Save RassilonSleeps/dbd98723398570f701087f6340080930 to your computer and use it in GitHub Desktop.
Convert all .ckpt files in all subdirectories to .safetensors files, NAI fix
# Got a bunch of .ckpt files to convert?
# Here's a handy script to take care of all that for you!
# Original .ckpt files are not touched!
# Make sure you have enough disk space! You are going to DOUBLE the size of your models folder!
#
# First, run:
# pip install torch torchsde==0.2.5 safetensors==0.2.5
#
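# Open a command prompt in the folder that contains your .ckpt files and run the script from there;
# that folder AND all of its subfolders are scanned. The easiest way is to place this file in that
# **SAME DIRECTORY** and run:
# python convert_to_safe.py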
#
# Original script by @xrpgame https://github.com/xrpgame
# https://gist.github.com/xrpgame/8f756f99b00b02697edcd5eec5202c59
#
# Edited by @Tumppi066 for use with folders https://github.com/Tumppi066/
# https://gist.github.com/Tumppi066/42482956139d79cb7c05e0b8f3cfef69
#
# Edited by @RassilonSleeps for NAI model compatibility https://github.com/RassilonSleeps
# https://gist.github.com/RassilonSleeps/edfb630819b95307270efa8450163bc1
import os
import torch
from safetensors.torch import save_file
# Conversion logic. Everything runs from main(), which is only invoked when
# this file is executed as a script (see the __main__ guard at the bottom).

CKPT_EXT = ".ckpt"
SAFETENSORS_EXT = ".safetensors"


def find_checkpoints(root):
    """Return the sorted paths of every .ckpt file under *root*, recursively.

    The extension is matched case-insensitively ("model.CKPT" is included).
    Sorting makes the numbered listing shown to the user deterministic
    instead of depending on filesystem order.
    """
    return sorted(
        os.path.join(dirpath, name)
        for dirpath, _dirnames, filenames in os.walk(root)
        for name in filenames
        if name.lower().endswith(CKPT_EXT)
    )


def safetensors_path(ckpt_path):
    """Return the output path for *ckpt_path*: same folder, ".safetensors" extension.

    Only the final extension is swapped, so an upper-case ".CKPT" is replaced
    too and a ".ckpt" appearing elsewhere in the path (e.g. in a folder name)
    is left alone.
    """
    return os.path.splitext(ckpt_path)[0] + SAFETENSORS_EXT


def extract_tensors(checkpoint):
    """Pick the tensors to save out of a loaded checkpoint.

    If *checkpoint* is a dict with a "state_dict" key, that inner dict holds
    the weights; otherwise *checkpoint* itself is treated as the state dict.

    Entries whose values are not ``torch.Tensor`` cannot be stored in a
    safetensors file and are left out. A nested "state_dict" key, as found in
    NAI models/merges, is always left out as well. Earlier versions of this
    script removed it with ``weights.pop("state_dict")``, which raised
    KeyError on every non-NAI model.

    Returns:
        ``(tensors, skipped)``: a dict mapping name -> tensor, and the list of
        keys that were left out.

    Raises:
        TypeError: if no dict of weights can be found.
    """
    weights = checkpoint
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        weights = checkpoint["state_dict"]
    if not isinstance(weights, dict):
        raise TypeError(f"expected a dict of weights, got {type(weights).__name__}")
    tensors = {}
    skipped = []
    for key, value in weights.items():
        if key != "state_dict" and isinstance(value, torch.Tensor):
            tensors[key] = value
        else:
            skipped.append(key)
    return tensors, skipped


def convert_checkpoint(ckpt_path, out_path):
    """Load *ckpt_path* and write its tensors to *out_path* as safetensors.

    The data is written to "<out_path>.tmp" and moved into place only after
    a successful save. A failed or interrupted conversion therefore never
    leaves a truncated .safetensors file that the next run would skip as
    "already exists".
    """
    # SECURITY: torch.load unpickles the file, and unpickling can run
    # arbitrary code. Only convert checkpoints from sources you trust.
    # map_location="cpu" lets checkpoints saved from GPU tensors load on
    # machines without a matching GPU.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    tensors, skipped = extract_tensors(checkpoint)
    # Drop our reference to the full checkpoint; only the selected tensors
    # are needed from here on.
    del checkpoint
    if skipped:
        print(f"Leaving out non-weight entries: {', '.join(map(str, skipped))}")
    print(f'Saving {os.path.basename(out_path)}...')
    tmp_path = out_path + ".tmp"
    try:
        save_file(tensors, tmp_path)
        os.replace(tmp_path, out_path)
    finally:
        # The temporary file only still exists if the save or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def main():
    """Interactively convert every .ckpt under the current working directory.

    Lists the checkpoints found (including subfolders) and waits for
    confirmation. Each one is then converted to a .safetensors file next to
    it. Checkpoints whose .safetensors output already exists are skipped.
    A failure on one file is reported and the remaining files are still
    processed.
    """
    root = os.path.abspath(os.getcwd())
    models = find_checkpoints(root)
    if not models:
        print('\033[91m> No .ckpt files found in this directory or its subdirectories ({}).\033[0m'.format(root))
        input('> Press enter to exit... ')
        return

    print(f"\n\033[92m> Found {len(models)} .ckpt files to convert.\033[0m")
    for number, model in enumerate(models, start=1):
        print(f"{number}: {os.path.basename(model)}")
    input("> Press enter to continue... ")
    print("\n")

    total = len(models)
    for number, ckpt_path in enumerate(models, start=1):
        model_name = os.path.basename(ckpt_path)  # printed instead of the full path
        out_path = safetensors_path(ckpt_path)
        tensor_name = os.path.basename(out_path)
        if os.path.exists(out_path):
            # Print the model name and skip it if it already exists in yellow
            print(f"\033[33m\n> Skipping {model_name}, as {tensor_name} already exists.\033[0m")
            continue
        print(f'\n> Loading {model_name} ({number}/{total})...')
        try:
            convert_checkpoint(ckpt_path, out_path)
        except Exception as ex:  # per-file boundary: report and keep going
            print(f'ERROR converting {model_name}: {ex}')

    print("\n\033[92mDone!\033[0m")
    input("> Press enter to exit... ")


if __name__ == "__main__":
    main()
@JustAnOkapi
Copy link

Edited this script to add weights.pop("state_dict") for converting NAI models/merges.

This should probably be wrapped in a try block, or only done for NAI models, as it causes errors on non-NAI models.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment