@ProGamerGov
Last active December 26, 2023 07:15
Replace the VAE in a Stable Diffusion model with a new VAE. Tested on v1.4 and v1.5 SD models.
# Script by https://github.com/ProGamerGov
import copy
import torch
# Path to model and VAE files that you want to merge
vae_file_path = "vae-ft-mse-840000-ema-pruned.ckpt"
model_file_path = "v1-5-pruned-emaonly.ckpt"
# Name to use for new model file
new_model_name = "v1-5-pruned-emaonly_ema_vae.ckpt"
# Load files
vae_model = torch.load(vae_file_path, map_location="cpu")
full_model = torch.load(model_file_path, map_location="cpu")
# Replace VAE in model file with new VAE
vae_dict = {k: v for k, v in vae_model["state_dict"].items() if k[0:4] not in ["loss", "mode"]}
for k, _ in vae_dict.items():
    key_name = "first_stage_model." + k
    full_model['state_dict'][key_name] = copy.deepcopy(vae_model["state_dict"][k])
# Save model with new VAE
torch.save(full_model, new_model_name)
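The key-remapping step in the script can be tried out without loading any checkpoint. The following is a minimal sketch, not part of the original gist: `replace_vae` is a hypothetical helper name, and the toy dictionaries use plain numbers in place of real tensors. It mirrors the script's filter (skip keys starting with "loss" or "mode") and its "first_stage_model." prefix.

```python
import copy


def replace_vae(model_sd, vae_sd, prefix="first_stage_model."):
    """Copy VAE weights into model_sd under `prefix`.

    Hypothetical helper: mirrors the loop in the script above.
    Returns the list of keys that were written.
    """
    written = []
    for key, value in vae_sd.items():
        # Same filter as the script: skip keys starting with "loss" or "mode"
        if key.startswith(("loss", "mode")):
            continue
        target = prefix + key
        model_sd[target] = copy.deepcopy(value)
        written.append(target)
    return written


# Toy stand-ins for real state dicts (numbers instead of tensors)
model_sd = {"first_stage_model.encoder.w": 1, "model.diffusion_model.w": 2}
vae_sd = {"encoder.w": 10, "loss.logvar": 99, "model_ema.decay": 0.5}

written = replace_vae(model_sd, vae_sd)
print(written)
```

With real checkpoints, the same function would presumably be called as `replace_vae(full_model["state_dict"], vae_model["state_dict"])` before saving.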
@bond007alex

Oh, this is super useful, thanks! I had no idea VAEs were stuffed into some of the .ckpt files.

@zhuofengli

zhuofengli commented Jan 18, 2023

I ran into:
KeyError: 'state_dict'
Any idea how to solve it?

@ProGamerGov
Author

@zhuofengli What model are you using? This script will only work on 1.x models.

@Astropulse

Astropulse commented Jan 26, 2023

This error is caused by some merged models, which have been flattened to only the 'state_dict' component already.

Editing the script as follows fixes the issue: it first checks whether each file contains a 'state_dict' key or has already been flattened.

# Check for flattened (merged) models
if 'state_dict' in full_model:
    full_model = full_model["state_dict"]
if 'state_dict' in vae_model:
    vae_model = vae_model["state_dict"]

# Replace VAE in model file with new VAE
vae_dict = {k: v for k, v in vae_model.items() if k[0:4] not in ["loss", "mode"]}
for k, _ in vae_dict.items():
    key_name = "first_stage_model." + k
    full_model[key_name] = copy.deepcopy(vae_model[k])
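The same check can be wrapped in a small function so it applies to both files. This is a rough sketch under assumptions: `unwrap_state_dict` is a hypothetical name, and the "global_step" entry in the toy data is only an illustrative example of extra metadata, not something confirmed for any particular checkpoint.

```python
def unwrap_state_dict(checkpoint):
    """Return the weights mapping, whether or not it is nested under 'state_dict'.

    Hypothetical helper based on the check in the comment above.
    """
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


# Toy stand-ins: one wrapped checkpoint and one already-flattened checkpoint
wrapped = {"state_dict": {"encoder.w": 1}, "global_step": 840000}  # illustrative metadata
flat = {"encoder.w": 1}

print(unwrap_state_dict(wrapped))
print(unwrap_state_dict(flat))
```

Note that once `full_model` has been reassigned to its inner dictionary, as in the snippet above, a later `torch.save(full_model, new_model_name)` writes only that flattened dictionary, so any other top-level entries in the original file are not carried over.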
