Created
June 12, 2025 17:19
-
-
Save jaretburkett/2453b8203d62416413ef0e1988aa1345 to your computer and use it in GitHub Desktop.
Stack AI-Toolkit LoRAs by concatenating them along the rank dimension. Be careful: the merged file grows quickly, since the ranks add up.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gc | |
from collections import OrderedDict | |
import os | |
# leave in this if for autoformatting purposes | |
if True: | |
import torch | |
from safetensors.torch import load_file, save_file | |
def flush():
    """Free as much memory as possible between LoRA loads.

    Runs a garbage-collection pass first so that unreferenced tensors are
    actually destroyed, then asks the CUDA caching allocator to release the
    memory they held. (The original order — empty_cache before gc.collect —
    could leave cached blocks behind, because the tensors were still alive
    when the cache was flushed.) Both calls are safe no-ops on CPU-only runs.
    """
    gc.collect()
    torch.cuda.empty_cache()
# Safetensors header metadata; "pt" marks the file as PyTorch-format weights.
metadata = OrderedDict([("format", "pt")])

# Add as many LoRAs as you want. Each entry is (path, lora_weight);
# adjust the weights accordingly — 1.0 is full weight.
lora_paths = [
    ("/path/to/lora.safetensors", 0.2),
    ("/path/to/lora.safetensors", 0.2),
    ("/path/to/lora.safetensors", 0.2),
]

# Where the stacked result is written.
output_path = "/path/to/save.safetensors"

# Tensors are saved in this dtype after merging.
dtype = torch.bfloat16
# NOTE(review): `device` is never referenced in the rest of the script — kept
# for interface compatibility only.
device = torch.device("cpu")
def _merge_lora_into(merged, lora_state_dict, multiplier):
    """Fold one LoRA state dict into ``merged``, stacking along the rank axis.

    Each tensor is upcast to float32 and scaled by ``multiplier`` before it is
    merged, so the per-LoRA weighting is applied at full precision. The first
    LoRA simply seeds ``merged``; subsequent ones are concatenated on the rank
    dimension, which grows the effective rank of the stacked LoRA.

    Args:
        merged: dict accumulating the stacked tensors (mutated in place).
        lora_state_dict: one LoRA's key -> tensor mapping.
        multiplier: scalar weight for this LoRA; 1.0 is full weight.

    Returns:
        The same ``merged`` dict, for convenience.

    Raises:
        ValueError: if a key contains neither 'lora_A' nor 'lora_B'.
    """
    for key, value in lora_state_dict.items():
        value = value.to(torch.float32) * multiplier
        if key not in merged:
            merged[key] = value
        elif "lora_A" in key:
            # lora_A (down projection): rank is dim 0, so stack rows.
            # BUGFIX: the original passed the accumulator as torch.cat's
            # `out=` tensor — torch.cat forbids `out` aliasing an input, and
            # the pre-concat tensor is too small to hold the result anyway.
            # Rebinding the dict entry to the new tensor is the correct form.
            merged[key] = torch.cat([merged[key], value], dim=0)
        elif "lora_B" in key:
            # lora_B (up projection): rank is dim 1, so stack columns.
            merged[key] = torch.cat([merged[key], value], dim=1)
        else:
            raise ValueError(
                f"Unexpected key format: {key}. Expected 'lora_A' or 'lora_B'."
            )
    return merged


output_state_dict = {}
for idx, (lora_path, multiplier) in enumerate(lora_paths):
    print(f"Loading LoRA {idx + 1}/{len(lora_paths)}")
    _merge_lora_into(output_state_dict, load_file(lora_path), multiplier)
    # Drop the just-loaded state dict's memory before the next (large) load.
    flush()
# Cast every merged tensor to the target dtype on CPU before serialization.
# (Rebinding values while iterating items() is safe — no keys are added.)
for key, value in output_state_dict.items():
    output_state_dict[key] = value.to("cpu", dtype)

print("Saving model...")
out_dir = os.path.dirname(output_path)
# BUGFIX: dirname is "" when output_path is a bare filename, and
# os.makedirs("") raises FileNotFoundError — only create a real directory.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
save_file(output_state_dict, output_path, metadata=metadata)
# BUGFIX: original message read "saved merge to to {path}".
print(f"Successfully saved merge to {output_path}")
print("Done.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment