Skip to content

Instantly share code, notes, and snippets.

@arenasys
Last active December 30, 2022 18:41
Show Gist options
  • Save arenasys/658bc94f4799d324b54e6e36200e0e1f to your computer and use it in GitHub Desktop.
VAE Extractor
import torch
# Top-level state-dict prefixes that belong to the VAE.
allowed_keys = [
    "encoder",
    "decoder",
    "quant_conv",
    "post_quant_conv"]


def extract_vae(model, half=False):
    """Return a new checkpoint dict holding only the VAE weights of *model*.

    Keys under the ``first_stage_model.`` namespace (the VAE inside a full
    Stable Diffusion checkpoint) are re-rooted by stripping that prefix;
    anything whose first dotted component is not in ``allowed_keys`` is
    skipped. Prints one COPYING/IGNORING line per key as it goes.

    Args:
        model: a loaded checkpoint — either a raw state dict or a dict
            with a ``'state_dict'`` entry wrapping one.
        half: when True, convert each kept tensor to fp16 via ``.half()``.

    Returns:
        ``{'state_dict': {...}}`` containing only the kept entries.
    """
    if 'state_dict' in model:
        model = model['state_dict']
    kept = {}
    for key in model:
        trimmed = key
        if key.startswith('first_stage_model'):
            trimmed = key.replace('first_stage_model.', '')
        if trimmed.split('.')[0] not in allowed_keys:
            print("IGNORING", key)
            continue
        print("COPYING", key)
        value = model[key]
        kept[trimmed] = value.half() if half else value
    return {'state_dict': kept}
def extract_cli():
    """Command-line entry point: load a checkpoint and save its VAE weights.

    Flags:
        --ckpt PATH  path to the model .ckpt (or .pt) file (required)
        -f           save the extracted weights in half precision (fp16)

    Writes ``<stem>.vae.pt`` next to the input, bumping to
    ``<stem>-2.vae.pt``, ``<stem>-3.vae.pt``, ... if the name is taken.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(description='VAE Extracter')
    parser.add_argument('--ckpt', type=str, default=None, required=True, help='path to model ckpt or vae pt')
    parser.add_argument('-f', action='store_true', help='half precision (fp16)')
    args = parser.parse_args()

    in_file = args.ckpt
    half = args.f
    if not os.path.exists(in_file):
        print("checkpoint not found")
        return

    in_size = os.path.getsize(in_file)
    print(f"INPUT {os.path.basename(in_file)}: {in_size*1e-9:.2f} GB\n")

    print("LOADING...\n")
    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # on checkpoints from sources you trust.
    model = torch.load(in_file, map_location="cpu")

    print("EXTRACTING...\n")
    out = extract_vae(model, half)

    print("\nSAVING...\n")
    # Derive the output stem robustly. The original used
    # in_file.replace(".ckpt", ".vae.pt"), which left the name unchanged for
    # non-.ckpt inputs (e.g. a .pt file, which the --ckpt help allows); the
    # collision loop below then replaced a ".vae.pt" substring that did not
    # exist and spun forever on the existing input path.
    if in_file.endswith(".ckpt"):
        root = in_file[:-len(".ckpt")]
    else:
        root = os.path.splitext(in_file)[0]
    out_file = root + ".vae.pt"
    i = 2
    while os.path.exists(out_file):  # never clobber an existing output
        out_file = root + f"-{i}.vae.pt"
        i += 1
    torch.save(out, out_file)

    out_size = os.path.getsize(out_file)
    print(f"OUTPUT {os.path.basename(out_file)}: {out_size*1e-9:.2f} GB")


if __name__ == "__main__":
    extract_cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment