Last active
March 17, 2023 18:48
-
-
Save AmericanPresidentJimmyCarter/1947162f371e601ce183070443f41dc2 to your computer and use it in GitHub Desktop.
Attempting to make small DreamBooth (DB) patches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import sys
from pathlib import Path

import torch

# Per-tensor cut-off divisor used by "compress": an entry of
# (base - dreambooth) is kept in the patch only if it is
# >= max(diff) / THRESHOLD_STRENGTH or <= min(diff) / THRESHOLD_STRENGTH
# (exact zeros are never kept).
THRESHOLD_STRENGTH = 2.
# Output path used by both modes when -o is not given.
DEFAULT_OUT_NAME = 'output.ckpt'

parser = argparse.ArgumentParser(
    description='Create a compressed dreambooth patch or patch weights',
)
# `choices` makes argparse reject unknown modes up front; without it any other
# value passed validation and the script exited without doing anything.
parser.add_argument(
    'mode',
    type=str,
    choices=('compress', 'inflate'),
    help='"compress" or "inflate"',
)
# The path options below are optional at the argparse level because which of
# them is needed depends on `mode`; their help text says which mode uses them.
parser.add_argument(
    '-m',
    type=str,
    help='Base model (Compvis); needed by both modes',
    required=False,
)
parser.add_argument(
    '-m2',
    type=str,
    help='Dreambooth model (Compvis); needed by "compress"',
    required=False,
)
# NOTE(review): -c is parsed but never read anywhere in this file.
parser.add_argument(
    '-c',
    type=str,
    help='Model configuration for compvis model',
    required=False,
)
parser.add_argument(
    '-p',
    type=str,
    help='Patch file name; needed by "inflate"',
    required=False,
)
parser.add_argument(
    '-o',
    type=str,
    help=f'Output file name (default: {DEFAULT_OUT_NAME})',
    required=False,
)
args = parser.parse_args()
# Check that every input file the selected mode needs was given and exists,
# before any (large) checkpoint is loaded. All path options are optional at
# the argparse level, so an omitted one arrives here as None; Path(None) would
# raise TypeError, hence the explicit None check.
if args.mode == 'compress':
    required_files = {'-m': args.m, '-m2': args.m2}
elif args.mode == 'inflate':
    required_files = {'-m': args.m, '-p': args.p}
else:
    sys.exit(f'Unknown mode {args.mode!r}, expected "compress" or "inflate"')
missing_files = [
    flag for flag, file_name in required_files.items()
    if file_name is None or not Path(file_name).is_file()
]
if missing_files:
    # sys.exit(<str>) prints the message to stderr and exits with status 1.
    sys.exit(
        'One or more input files is not a file, please provide a correct '
        f'path (check {", ".join(missing_files)})'
    )
out_path = args.o or DEFAULT_OUT_NAME
if args.mode == 'compress':
    # Build a sparse patch: the base state dict in which every processed
    # tensor is replaced by a sparse tensor holding the Dreambooth values at
    # positions where the two models differ significantly. Keys the loop does
    # not process keep their base-model value.
    # NOTE(review): torch.load unpickles arbitrary objects - only load
    # checkpoints from trusted sources.
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts.
    model = torch.load(args.m, map_location='cpu')
    model_db = torch.load(args.m2, map_location='cpu')
    m_state = model['state_dict']
    m_db_state = model_db['state_dict']
    # Walk the Dreambooth keys. Walking the base keys (as before) made the
    # "key only exists in the Dreambooth model" case unreachable.
    for key, db_tensor in m_db_state.items():
        if 'model' not in key or not isinstance(db_tensor, torch.Tensor):
            continue
        if key not in m_state:
            # Parameter only present in the Dreambooth model: ship it whole.
            m_state[key] = db_tensor.to_sparse()
            continue
        base_tensor = m_state[key]
        if not isinstance(base_tensor, torch.Tensor):
            continue
        # The patch is produced in half precision.
        if base_tensor.dtype == torch.float32:
            base_tensor = base_tensor.half()
        if db_tensor.dtype == torch.float32:
            db_tensor = db_tensor.half()
        if base_tensor.shape != db_tensor.shape:
            # Subtracting would raise or, worse, silently broadcast.
            sys.exit(
                f'Shape mismatch for {key}: base model has '
                f'{tuple(base_tensor.shape)}, Dreambooth model has '
                f'{tuple(db_tensor.shape)}'
            )
        if db_tensor.numel() == 0:
            # torch.max / torch.min raise on empty tensors.
            m_state[key] = db_tensor.to_sparse()
            continue
        diff = base_tensor - db_tensor
        # Ignore small differences: keep entries with
        # diff >= max(diff) / THRESHOLD_STRENGTH or
        # diff <= min(diff) / THRESHOLD_STRENGTH, never exact zeros.
        threshold_pos = torch.max(diff) / THRESHOLD_STRENGTH
        threshold_neg = torch.min(diff) / THRESHOLD_STRENGTH
        significant = (
            ((diff >= threshold_pos) | (diff <= threshold_neg)) & (diff != 0)
        )
        # NOTE: a significant position whose Dreambooth value is exactly 0 is
        # stored as 0, which "inflate" treats as "keep the base weight".
        m_state[key] = (db_tensor * significant).to_sparse()
        del diff, significant
    del m_db_state
    del model_db
    # Save patch
    torch.save(m_state, out_path)
if args.mode == 'inflate':
    # Apply a patch made by "compress" to the base model: every non-zero patch
    # entry overwrites the corresponding base weight; zero entries leave the
    # base weight unchanged.
    # NOTE(review): torch.load unpickles arbitrary objects - patches are meant
    # to be shared, so only load ones from trusted sources.
    # map_location='cpu' lets files saved on a GPU load on CPU-only hosts.
    model = torch.load(args.m, map_location='cpu')
    model_patch = torch.load(args.p, map_location='cpu')
    m_state = model['state_dict']
    m_patch_state = model_patch
    # Walk the patch keys. Walking the base keys (as before) made the
    # "key only exists in the patch" case unreachable.
    for key, patch_tensor in m_patch_state.items():
        if 'model' not in key or not isinstance(patch_tensor, torch.Tensor):
            continue
        # Patches from "compress" hold sparse tensors for processed keys but
        # keep untouched base keys dense, so only densify sparse entries.
        if patch_tensor.is_sparse:
            dense_patch = patch_tensor.to_dense()
        else:
            dense_patch = patch_tensor
        if key not in m_state:
            m_state[key] = dense_patch
            continue
        base_tensor = m_state[key]
        if not isinstance(base_tensor, torch.Tensor):
            continue
        if base_tensor.dtype == torch.float32:
            base_tensor = base_tensor.half()
        # Cast the patch to the base dtype (fp16 for float weights after the
        # line above) so integer buffers are not silently promoted to half.
        inflated = dense_patch.to(base_tensor.dtype)
        if base_tensor.size() != inflated.size():
            # Explicit check instead of `assert`, which is stripped under -O.
            sys.exit(
                f'Shape mismatch for {key}: base model has '
                f'{tuple(base_tensor.size())}, patch has '
                f'{tuple(inflated.size())}'
            )
        # torch.where instead of base * (1 - mask) + patch: multiplying an
        # inf/NaN base weight by 0 would turn it into NaN.
        m_state[key] = torch.where(inflated != 0, inflated, base_tensor)
        del dense_patch, inflated
    # Save patched model
    del m_patch_state
    del model_patch
    model['state_dict'] = m_state
    torch.save(model, out_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment