Merge an ai-toolkit LoRA into a diffusers HiDream model.
#!/usr/bin/env python3
import os
import argparse

import torch
from safetensors.torch import load_file, save_file

# NOTE: When done, move the newly created transformer file into the transformer/ subdirectory of your diffusers model,
# and remove the old .safetensors shards and the .index.json file (no longer needed, since the merged file is not sharded).
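#
# Illustrative sketch of that cleanup (shard names below are hypothetical, following the usual
# diffusers sharded-checkpoint naming); only the merged single-file transformer remains afterwards:
#
#   HiDream-I1-Full/transformer/
#     diffusion_pytorch_model-00001-of-0000N.safetensors    <- old shards: delete
#     ...
#     diffusion_pytorch_model.safetensors.index.json        <- shard index: delete
#     diffusion_pytorch_model.safetensors                    <- output of this script: keep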


def manual_merge(
    base_transformer_dir: str,
    lora_path: str,
    output_path: str,
    lora_scale: float = 1.0,
):
    # 1) Load base transformer shards
    base_sd = {}
    for fname in os.listdir(base_transformer_dir):
        path = os.path.join(base_transformer_dir, fname)
        if fname.endswith(".safetensors"):
            shard = load_file(path, device="cuda")
        elif fname.endswith(".bin"):
            # .bin shards are plain PyTorch checkpoints; load_file() only reads safetensors.
            shard = torch.load(path, map_location="cuda")
        else:
            continue
        base_sd.update(shard)

    # Print base model keys
    print(f"Base model keys: {list(base_sd.keys())}")

    # 2) Load your LoRA/DoRA file
    print(f"Loading LoRA/DoRA from {lora_path}...")
    lora_sd = load_file(lora_path, device="cuda")

    # 3) Collect A/B/mag pieces
    deltas = {}
    for k, v in lora_sd.items():
        print(f"Processing key: {k}")
        if k.endswith(".lora_A.weight"):
            mod = k[:-len(".lora_A.weight")]
            deltas.setdefault(mod, {})["A"] = v
        elif k.endswith(".lora_B.weight"):
            mod = k[:-len(".lora_B.weight")]
            deltas.setdefault(mod, {})["B"] = v
        elif k.endswith(".lora_magnitude_vector"):
            mod = k[:-len(".lora_magnitude_vector")]
            deltas.setdefault(mod, {})["mag"] = v

    # 4) Apply each update: W_new = W_base + scale * (B @ A) * mag
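    # Shape sketch for a single Linear module (illustrative; standard LoRA/DoRA tensor layout):
    #   base weight W   : (out_dim, in_dim)
    #   lora_A.weight   : (rank, in_dim)
    #   lora_B.weight   : (out_dim, rank)
    #   B @ A           : (out_dim, in_dim)  -> same shape as W, so it can be added directly
    #   magnitude (mag) : (out_dim,)         -> unsqueeze(1) broadcasts it across in_dim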
    for mod, pieces in deltas.items():
        print(f"Applying LoRA/DoRA to module: {mod}")
        A = pieces.get("A")
        B = pieces.get("B")
        if A is None or B is None:
            print(f"Skipping module {mod} due to missing A or B")
            continue

        delta = B @ A  # shape (out_dim, in_dim)
        mag = pieces.get("mag")
        if mag is not None:
            delta = delta * mag.unsqueeze(1)
        delta = delta * lora_scale

        base_key = f"{mod}.weight"
        remove_prefix = "diffusion_model."
        if base_key.startswith(remove_prefix):
            base_key = base_key[len(remove_prefix):]
        if base_key not in base_sd:
            raise KeyError(f"Base weight not found for module '{mod}' (looked for '{base_key}')")
        # Cast the delta to the base dtype so the merged weights keep the original precision.
        base_sd[base_key] = base_sd[base_key] + delta.to(base_sd[base_key].dtype)

    # 5) Save one big safetensors file
    out_dir = os.path.dirname(output_path)
    if out_dir:  # output_path may be a bare filename with no directory component
        os.makedirs(out_dir, exist_ok=True)
    save_file(base_sd, output_path)
    print(f"✅ Merged safetensors written to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_transformer_dir", required=True,
        help="e.g. path/to/HiDream-I1-Full/transformer"
    )
    parser.add_argument(
        "--lora_path", required=True,
        help="Path to your LoRA/DoRA safetensors"
    )
    parser.add_argument(
        "--output_path", required=True,
        help="Path to output diffusion_pytorch_model.safetensors"
    )
    parser.add_argument(
        "--lora_scale", type=float, default=1.0,
        help="Global LoRA/DoRA scale"
    )
    args = parser.parse_args()
    manual_merge(
        base_transformer_dir=args.base_transformer_dir,
        lora_path=args.lora_path,
        output_path=args.output_path,
        lora_scale=args.lora_scale,
    )
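
Usage sketch (the script filename and every path below are placeholders, not part of the gist): the script is normally run through its argparse CLI, but manual_merge() can also be called directly from Python.

# Hypothetical CLI invocation, assuming the gist was saved as merge_hidream_lora.py:
#   python merge_hidream_lora.py \
#       --base_transformer_dir HiDream-I1-Full/transformer \
#       --lora_path my_lora.safetensors \
#       --output_path merged_transformer/diffusion_pytorch_model.safetensors \
#       --lora_scale 1.0
#
# Equivalent programmatic call (importing is safe: the argparse block is guarded by __main__):
from merge_hidream_lora import manual_merge

manual_merge(
    base_transformer_dir="HiDream-I1-Full/transformer",
    lora_path="my_lora.safetensors",
    output_path="merged_transformer/diffusion_pytorch_model.safetensors",
    lora_scale=1.0,
)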