Created
October 9, 2024 15:15
Merge multiple LoRAs that share the same weight naming convention into a single LoRA.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gc | |
from collections import OrderedDict | |
# leave in this if for autoformatting purposes | |
if True: | |
import torch | |
from safetensors.torch import load_file, save_file | |
def flush():
    """Release memory between LoRA loads.

    Runs Python garbage collection *first* so tensors freed by the GC can
    then actually be returned from the CUDA caching allocator; the original
    order (empty_cache before collect) left just-collected tensors cached.
    The CUDA call is skipped on CPU-only machines.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# ---------------------------------------------------------------------------
# Merge configuration
# ---------------------------------------------------------------------------

# Header metadata embedded into the output safetensors file.
metadata = OrderedDict(format="pt")

# Add as many entries as you want: each item is (path, weight),
# where a weight of 1.0 means full strength.
lora_paths = [
    ("/path/to/lora1.safetensors", 0.33),
    ("/path/to/lora2.safetensors", 0.33),
    ("/path/to/lora3.safetensors", 0.33),
]

output_path = "/save/path/merged_lora.safetensors"

dtype = torch.float16          # precision the merged weights are saved in
device = torch.device("cpu")   # NOTE(review): defined but never used below
output_state_dict = {}         # accumulates the weighted sum of all LoRAs
def pad_tensor(tensor, target_shape):
    """Zero-pad `tensor` on the high side of each dimension up to `target_shape`.

    `target_shape` is indexed per dimension (in this script it is a 1-D
    tensor of sizes from `torch.max` over two shapes). Dimensions that are
    already at least the target size are left untouched, and any dimensions
    of `tensor` beyond `len(target_shape)` receive no padding.
    """
    pad_spec = []
    # torch.nn.functional.pad consumes (before, after) pairs starting from
    # the LAST dimension, hence the reversed walk over the dims.
    for dim in reversed(range(len(tensor.shape))):
        if dim < len(target_shape):
            grow = max(0, int(target_shape[dim]) - int(tensor.shape[dim]))
            pad_spec.extend([0, grow])
        else:
            pad_spec.extend([0, 0])
    return torch.nn.functional.pad(tensor, pad_spec)
# Accumulate each LoRA's weighted tensors into output_state_dict. When the
# same key appears with different shapes across LoRAs, both tensors are
# zero-padded up to the element-wise maximum shape before being summed.
for idx, (lora_path, multiplier) in enumerate(lora_paths):
    print(f"Loading LoRA {idx + 1}/{len(lora_paths)}")
    for key, raw in load_file(lora_path).items():
        # Accumulate in float32 so repeated additions keep precision.
        scaled = raw.to(torch.float32) * multiplier
        accumulated = output_state_dict.get(key)
        if accumulated is None:
            output_state_dict[key] = scaled
        else:
            # Element-wise max of the two shapes -> common padded shape.
            target_shape = torch.max(
                torch.tensor(accumulated.shape), torch.tensor(scaled.shape))
            output_state_dict[key] = (
                pad_tensor(accumulated, target_shape)
                + pad_tensor(scaled, target_shape))
    flush()
# Downcast the accumulated float32 tensors to the output dtype on CPU
# before serializing.
for key, value in output_state_dict.items():
    output_state_dict[key] = value.to('cpu', dtype)

print("Saving model...")
save_file(output_state_dict, output_path, metadata=metadata)
# Fix: original message read "saved merge to to {output_path}" (duplicated word).
print(f"Successfully saved merge to {output_path}")
print("Done.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment