Created
January 28, 2023 16:37
-
-
Save IzumiSatoshi/e6742cfa10ef9a0ae9cb2f56eac99c34 to your computer and use it in GitHub Desktop.
diffusers形式のInstructPix2Pixを通常モデルとマージ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 動作未確認 | |
import gc

import torch
from diffusers import (
    EulerAncestralDiscreteScheduler,
    StableDiffusionInstructPix2PixPipeline,
    StableDiffusionPipeline,
)
# Model locations (Google Colab + Google Drive layout).
IP2P_PATH = "/content/drive/MyDrive/models/diffusers/ip2p"
BASE_PATH = "/content/drive/MyDrive/models/diffusers/sd15_fp16"
TARGET_PATH = "/content/drive/MyDrive/models/diffusers/dedede_fp16"
OUTPUT_PATH = "/content/drive/MyDrive/models/diffusers/dedede_ip2p"


def merge_add_difference(ip2p_dict, base_dict, target_dict, skip_keys=("conv_in.weight",)):
    """Return a new state dict equal to ``ip2p + target - base`` key by key.

    The difference ``target - base`` (what the fine-tuned target learned on
    top of the base model) is added onto the InstructPix2Pix weights.

    Args:
        ip2p_dict: State dict of the InstructPix2Pix UNet. Its key set and
            dtypes define the result. It is not modified.
        base_dict: State dict of the base model the target was tuned from.
        target_dict: State dict of the fine-tuned target model.
        skip_keys: Keys copied unchanged from ``ip2p_dict``. The default keeps
            ``conv_in.weight``, whose input-channel count differs between
            InstructPix2Pix and a regular SD UNet.

    Returns:
        A dict with exactly the keys of ``ip2p_dict``. Skipped keys (listed
        in ``skip_keys``, absent from base/target, or shape-mismatched) keep
        the ip2p tensor; each skip is reported with ``print``.
    """
    merged = {}
    for key, ip2p_tensor in ip2p_dict.items():
        base_tensor = base_dict.get(key)
        target_tensor = target_dict.get(key)
        if (
            key in skip_keys
            or base_tensor is None
            or target_tensor is None
            or not (ip2p_tensor.shape == base_tensor.shape == target_tensor.shape)
        ):
            print(f"skip: {key}")
            merged[key] = ip2p_tensor
            continue
        # Compute in float32: fp16 add/sub loses precision. Some CPU PyTorch
        # builds also lack Half kernels. The result is cast back to the
        # ip2p dtype so the saved model keeps its original precision.
        delta = target_tensor.float() - base_tensor.float()
        merged[key] = (ip2p_tensor.float() + delta).to(ip2p_tensor.dtype)
    return merged


def main():
    """Load the three models, merge their UNets and save the result."""
    ip2p_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(IP2P_PATH, torch_dtype=torch.float16)
    base_pipe = StableDiffusionPipeline.from_pretrained(BASE_PATH, torch_dtype=torch.float16)
    target_pipe = StableDiffusionPipeline.from_pretrained(TARGET_PATH, torch_dtype=torch.float16)

    # merged = ip2p + target - base
    merged = merge_add_difference(
        ip2p_pipe.unet.state_dict(),
        base_pipe.unet.state_dict(),
        target_pipe.unet.state_dict(),
    )

    # Keep only the target VAE, then release the donor pipelines before saving.
    target_vae = target_pipe.vae
    del base_pipe, target_pipe
    gc.collect()

    ip2p_pipe.unet.load_state_dict(merged)
    ip2p_pipe.vae = target_vae
    # NOTE(review): the saved config is still the original ip2p one; the
    # original author warned this might cause problems — confirm on load.
    ip2p_pipe.save_pretrained(OUTPUT_PATH, safe_serialization=True)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment