Created October 22, 2022 15:23
Gist: Quasimondo/f344659f57dc15bd7892a969bd58ac67
Quick script to merge the fine-tuned StabilityAI autoencoder (VAE) into the RunwayML Stable Diffusion 1.5 checkpoint
```python
import torch

# USE AT YOUR OWN RISK

# Local path to the RunwayML SD 1.5 checkpoint (https://huggingface.co/runwayml/stable-diffusion-v1-5)
ckpt_15 = "./v1-5-pruned-emaonly.ckpt"
# Local path to the StabilityAI fine-tuned autoencoder (https://huggingface.co/stabilityai/sd-vae-ft-mse)
ckpt_vae = "./vae-ft-mse-840000-ema-pruned.ckpt"
# Path to save the merged model to
ckpt_out = "./v1-5-pruned-emaonly_new_vae.ckpt"

# Load both checkpoints on the CPU so no GPU is needed
pl_sd = torch.load(ckpt_15, map_location="cpu")
sd = pl_sd["state_dict"]
over_sd = torch.load(ckpt_vae, map_location="cpu")["state_dict"]

# In the SD checkpoint the autoencoder weights live under the "first_stage_model." prefix.
# Overwrite every one of them that has a counterpart in the fine-tuned VAE; skip the rest.
replaced = 0
for key in over_sd.keys():
    full_key = "first_stage_model." + key
    if full_key in sd:
        sd[full_key] = over_sd[key]
        replaced += 1
        print(key, "overwritten")

# A count of zero means one of the input files is probably not the expected checkpoint
print(replaced, "tensors overwritten")
torch.save(pl_sd, ckpt_out)
```
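The merge itself is a key-prefix overwrite: each tensor in the VAE checkpoint replaces the tensor of the same name under `first_stage_model.` in the SD 1.5 state dict, and VAE keys without a counterpart are skipped. The sketch below factors that loop into a hypothetical helper, `merge_prefixed`, so it can be exercised without the real files. It uses plain dicts with string placeholders instead of tensors, and the key names are illustrative only; the logic does not depend on the value type.

```python
from typing import Any, Dict, List


def merge_prefixed(target: Dict[str, Any], source: Dict[str, Any],
                   prefix: str = "first_stage_model.") -> List[str]:
    """Overwrite target[prefix + key] with source[key] wherever that key exists.

    Like the loop in the script, keys in `source` with no counterpart in
    `target` are skipped and `target` is modified in place.
    Returns the source keys that were overwritten, in iteration order.
    """
    overwritten = []
    for key, value in source.items():
        full_key = prefix + key
        if full_key in target:
            target[full_key] = value
            overwritten.append(key)
    return overwritten


# Toy stand-ins for state dicts (real ones map names to torch tensors)
model_sd = {
    "model.diffusion_model.input_blocks.0.0.weight": "unet",
    "first_stage_model.decoder.conv_out.weight": "old_vae",
    "first_stage_model.encoder.conv_in.weight": "old_vae",
}
vae_sd = {
    "decoder.conv_out.weight": "new_vae",
    "encoder.conv_in.weight": "new_vae",
    "unrelated.key": "ignored",
}

print(merge_prefixed(model_sd, vae_sd))
# → ['decoder.conv_out.weight', 'encoder.conv_in.weight']
```

Because the check is by name only, the helper never adds new keys to the target, so the UNet and text-encoder weights are left untouched.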
Hard to tell. I suspect you are either using the wrong checkpoint, or the checkpoints at the download locations have changed.
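One way to narrow down a wrong or changed checkpoint is to compare the key sets before merging. The sketch below is a hypothetical helper, `compare_vae_keys`, assuming both files load as in the script (the SD keys from `pl_sd["state_dict"]`, the VAE keys from `over_sd`). It reports which VAE keys would be overwritten, which would be silently skipped, and which autoencoder keys in the SD checkpoint the VAE file does not cover. It compares names only, not tensor shapes.

```python
from typing import Iterable, NamedTuple, Set


class KeyReport(NamedTuple):
    matched: Set[str]          # VAE keys with a prefixed counterpart in the model
    vae_only: Set[str]         # VAE keys the merge would skip
    model_uncovered: Set[str]  # model autoencoder keys the VAE file does not replace


def compare_vae_keys(model_keys: Iterable[str], vae_keys: Iterable[str],
                     prefix: str = "first_stage_model.") -> KeyReport:
    # Strip the prefix from the model's autoencoder keys so both sides use the same names
    model_vae = {k[len(prefix):] for k in model_keys if k.startswith(prefix)}
    vae = set(vae_keys)
    return KeyReport(
        matched=vae & model_vae,
        vae_only=vae - model_vae,
        model_uncovered=model_vae - vae,
    )


# Illustrative key names only
report = compare_vae_keys(
    ["first_stage_model.encoder.conv_in.weight", "cond_stage_model.some.key"],
    ["encoder.conv_in.weight", "some.other.key"],
)
print(len(report.matched), sorted(report.vae_only), sorted(report.model_uncovered))
# → 1 ['some.other.key'] []
```

With the real files this would be called as `compare_vae_keys(sd.keys(), over_sd.keys())`; an empty `matched` set means the script would write out an unchanged checkpoint.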