Created November 12, 2023 18:27
Save rockerBOO/6d2dbc7827c83bf4273e7381636ce9ff to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import torch | |
from safetensors.torch import load_file, safe_open | |
from library import model_util | |
def load_state_dict(file_name, dtype):
    """Load a LoRA checkpoint and cast its tensor values to ``dtype``.

    Args:
        file_name: Path to the checkpoint. Paths accepted by
            ``model_util.is_safetensors`` are read with safetensors; anything
            else is read with ``torch.load`` onto the CPU.
        dtype: torch dtype that every tensor value is converted to. Non-tensor
            values are left untouched.

    Returns:
        ``(state_dict, metadata)``. For safetensors files ``metadata`` is
        whatever ``safe_open(...).metadata()`` reports. For ``torch.load``
        checkpoints it is always ``None``.
    """
    if model_util.is_safetensors(file_name):
        # Read tensors and header metadata through a single handle instead of
        # opening and parsing the file twice (load_file + safe_open).
        with safe_open(file_name, framework="pt") as f:
            metadata = f.metadata()
            sd = {key: f.get_tensor(key) for key in f.keys()}
    else:
        # NOTE(review): torch.load unpickles arbitrary objects. Only use this
        # path with trusted checkpoints (consider weights_only=True on torch
        # versions that support it).
        sd = torch.load(file_name, map_location="cpu")
        metadata = None

    # Reassigning values of existing keys is safe while iterating the dict.
    for key, value in sd.items():
        if isinstance(value, torch.Tensor):
            sd[key] = value.to(dtype)
    return sd, metadata
def _lora_delta_weight(up, down):
    """Compose the LoRA factors into a full weight delta shaped like the layer."""
    if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
        # 1x1 conv LoRA: drop the trailing spatial dims, multiply, restore them.
        product = up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)
        return product.unsqueeze(2).unsqueeze(3)
    if up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
        # 3x3 conv LoRA: treat down's input channels as a batch and convolve
        # with up as the kernel, then swap back to (out, in, kh, kw).
        # NOTE(review): this assumes the 3x3 kernel lives in lora_down and
        # lora_up is 1x1; a 3x3 lora_up would fail inside conv2d.
        composed = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up)
        return composed.permute(1, 0, 2, 3)
    # Linear LoRA: plain matrix product.
    return up @ down


def get_norms(state_dict, device):
    """Compute the norm of each LoRA module's scaled weight delta.

    For every key containing both ``lora_down`` and ``weight``, the matching
    ``lora_up`` key and ``alpha`` key are looked up. The factors are composed
    into a full weight delta, scaled by ``alpha / rank`` (rank is
    ``down.shape[0]``), and ``Tensor.norm()`` of the result is recorded.

    Args:
        state_dict: Mapping of key to value, as returned by load_state_dict.
        device: Device the factors are moved to before composing them.

    Returns:
        ``(norms, longest_key)``. ``norms`` is a list of single-entry dicts
        ``{module_key: norm}`` in state-dict order. ``longest_key`` is the
        length of the longest module key (0 when no LoRA modules are found)
        and is used to align printed output.
    """
    down_keys = [k for k in state_dict if "lora_down" in k and "weight" in k]

    norms = []
    for down_key in down_keys:
        up_key = down_key.replace("lora_down", "lora_up")
        alpha_key = down_key.replace("lora_down.weight", "alpha")

        down = state_dict[down_key].to(device)
        up = state_dict[up_key].to(device)
        dim = down.shape[0]

        # Some LoRA files omit alpha. Treat a missing alpha as alpha == rank
        # (scale 1.0) instead of failing with KeyError.
        alpha = state_dict.get(alpha_key)
        scale = alpha.to(device) / dim if alpha is not None else 1.0

        updown = _lora_delta_weight(up, down) * scale
        save_key = down_key.replace(".lora_down", "")
        norms.append({save_key: updown.norm().item()})

    longest_key = max((len(key) for norm in norms for key in norm), default=0)
    return norms, longest_key
def main(args):
    """Print the norm of every LoRA module in ``args.model``, one per line.

    The checkpoint is loaded as float32. Norms are computed on CUDA when
    ``torch.cuda.is_available()`` reports it, otherwise on the CPU. Keys are
    left-aligned to the longest key so the values form a column.

    Args:
        args: Parsed CLI namespace; only ``args.model`` (checkpoint path) is read.
    """
    # The checkpoint metadata is not needed here, so it is discarded explicitly.
    lora_sd, _ = load_state_dict(args.model, torch.float32)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    norms, longest_key = get_norms(lora_sd, device)
    for norm in norms:
        for key, value in norm.items():
            print(f"{key:<{longest_key}} {value}")
if __name__ == "__main__":
    # CLI entry point: takes one positional checkpoint path and hands the
    # parsed namespace to main().
    cli = argparse.ArgumentParser(
        description="Check the norm values for the weights in a LoRA model",
    )
    cli.add_argument("model", help="LoRA model to check the norms of")
    main(cli.parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.