Created
September 14, 2022 13:48
-
-
Save opparco/62511cea8b185fc5238c62099a89f593 to your computer and use it in GitHub Desktop.
Generate a merged model by linearly interpolating the weights of two checkpoints.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import torch | |
def _parse_args():
    """Parse the command-line options for the checkpoint merge."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt0",
        type=str,
        default="sd-v1-4.ckpt",
        help="path to checkpoint of model 0",
    )
    parser.add_argument(
        "--ckpt1",
        type=str,
        default="wd-v1-2-full-ema.ckpt",
        help="path to checkpoint of model 1",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="alpha: (1-alpha) * theta_0 + alpha * theta_1",
    )
    parser.add_argument(
        "--outpath",
        type=str,
        default="tempered-waifu.ckpt",
        help="path to checkpoint of tempered model",
    )
    return parser.parse_args()


def _load_checkpoint(path):
    """Load a checkpoint onto the CPU and return ``(checkpoint, state_dict)``.

    If the file contains a ``'state_dict'`` entry, that entry is returned as the
    state dict; otherwise the loaded object itself is treated as a bare state
    dict (so ``checkpoint is state_dict``).

    Loading with ``map_location="cpu"`` means a checkpoint saved from a GPU can
    be merged on a machine without CUDA, and two large models are not pushed
    onto the GPU just to average their weights.
    """
    # torch.load unpickles the file: only merge checkpoints from trusted sources.
    # NOTE(review): on PyTorch versions where torch.load defaults to
    # weights_only=True, checkpoints that pickle non-tensor objects may refuse
    # to load -- confirm against the installed version.
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    return checkpoint, state_dict


def _merge_state_dicts(theta_0, theta_1, alpha):
    """Blend ``theta_1`` into ``theta_0`` in place as ``(1-alpha)*theta_0 + alpha*theta_1``.

    Only keys containing the substring ``"model"`` are considered:

    * keys in both dicts are linearly interpolated (alpha outside [0, 1]
      extrapolates);
    * keys only in ``theta_1`` are copied into ``theta_0``;
    * keys only in ``theta_0`` are left unchanged.

    Non-tensor values and non-floating-point tensors (e.g. integer index
    buffers) are kept from ``theta_0`` instead of being averaged, because
    interpolating them would silently turn them into floats.

    Raises:
        ValueError: if a shared floating-point key has different shapes in the
            two dicts (plain arithmetic could otherwise broadcast silently).
    """
    for key, a in list(theta_0.items()):
        if "model" not in key or key not in theta_1:
            continue
        b = theta_1[key]
        if not (torch.is_tensor(a) and torch.is_tensor(b)):
            continue
        if not (a.is_floating_point() and b.is_floating_point()):
            continue
        if a.shape != b.shape:
            raise ValueError(
                f"shape mismatch for {key!r}: {tuple(a.shape)} vs {tuple(b.shape)}"
            )
        # Compute in at least float32 (half-precision arithmetic is not
        # reliable on every CPU build), then store in the dtype plain
        # arithmetic on the two tensors would have produced.
        out_dtype = torch.promote_types(a.dtype, b.dtype)
        compute_dtype = torch.promote_types(out_dtype, torch.float32)
        merged = (1 - alpha) * a.to(compute_dtype) + alpha * b.to(compute_dtype)
        theta_0[key] = merged.to(out_dtype)

    for key, value in theta_1.items():
        if "model" in key and key not in theta_0:
            theta_0[key] = value


def main():
    """Merge the weights of two checkpoints and save the result.

    The output keeps the structure of checkpoint 0. Its state dict is blended in
    place with checkpoint 1 as ``(1-alpha)*theta_0 + alpha*theta_1`` and then
    written to ``--outpath``.
    """
    opt = _parse_args()
    model_0, theta_0 = _load_checkpoint(opt.ckpt0)
    _, theta_1 = _load_checkpoint(opt.ckpt1)

    _merge_state_dicts(theta_0, theta_1, opt.alpha)

    outpath = opt.outpath
    print(f"Save checkpoint of tempered model: \n{outpath} \n")
    # theta_0 was mutated in place, so model_0 already holds the merged weights
    # alongside any other entries from checkpoint 0.
    torch.save(model_0, outpath)
# Run the merge only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment