Merge an ai-toolkit LoRA into a diffusers HiDream model.
#!/usr/bin/env python3
import os
import argparse
import torch
from safetensors.torch import load_file, save_file
# NOTE: When done, move the newly created transformer file into the transformer/ subdirectory of your
# diffusers model, and remove the old .safetensors shards and the .index.json file (which is no longer
# needed because the merged output is a single, unsharded file).
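# For example, with a purely illustrative layout where the merged weights were written to
# merged/diffusion_pytorch_model.safetensors and the diffusers model lives in HiDream-I1-Full/,
# the cleanup could look roughly like this (exact shard/index filenames may differ):
#   rm HiDream-I1-Full/transformer/diffusion_pytorch_model-0000*-of-*.safetensors
#   rm HiDream-I1-Full/transformer/diffusion_pytorch_model.safetensors.index.json
#   mv merged/diffusion_pytorch_model.safetensors HiDream-I1-Full/transformer/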


def manual_merge(
    base_transformer_dir: str,
    lora_path: str,
    output_path: str,
    lora_scale: float = 1.0,
):
    # 1) Load base transformer shards (use torch.load for .bin shards, since safetensors' load_file cannot read them)
    base_sd = {}
    for fname in os.listdir(base_transformer_dir):
        path = os.path.join(base_transformer_dir, fname)
        if fname.endswith(".safetensors"):
            shard = load_file(path, device="cuda")
        elif fname.endswith(".bin"):
            shard = torch.load(path, map_location="cuda")
        else:
            continue
        base_sd.update(shard)

    # Print base model keys
    print(f"Base model keys: {list(base_sd.keys())}")

    # 2) Load your LoRA/DoRA file
    print(f"Loading LoRA/DoRA from {lora_path}...")
    lora_sd = load_file(lora_path, device="cuda")

    # 3) Collect A/B/mag pieces
    deltas = {}
    for k, v in lora_sd.items():
        print(f"Processing key: {k}")
        if k.endswith(".lora_A.weight"):
            mod = k[: -len(".lora_A.weight")]
            deltas.setdefault(mod, {})["A"] = v
        elif k.endswith(".lora_B.weight"):
            mod = k[: -len(".lora_B.weight")]
            deltas.setdefault(mod, {})["B"] = v
        elif k.endswith(".lora_magnitude_vector"):
            mod = k[: -len(".lora_magnitude_vector")]
            deltas.setdefault(mod, {})["mag"] = v

    # 4) Apply each update: W_new = W_base + scale * (B @ A) * mag
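    # (In the usual PEFT-style layout assumed here, lora_A.weight is (rank, in_dim) and
    #  lora_B.weight is (out_dim, rank), so B @ A expands back to a full (out_dim, in_dim)
    #  update; "mag" is the optional per-output-channel DoRA magnitude vector.)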
    for mod, pieces in deltas.items():
        print(f"Applying LoRA/DoRA to module: {mod}")
        A = pieces.get("A")
        B = pieces.get("B")
        if A is None or B is None:
            print(f"Skipping module {mod} due to missing A or B")
            continue
        delta = B @ A  # shape (out_dim, in_dim)
        mag = pieces.get("mag")
        if mag is not None:
            delta = delta * mag.unsqueeze(1)
        delta = delta * lora_scale

        base_key = f"{mod}.weight"
        remove_prefix = "diffusion_model."
        if base_key.startswith(remove_prefix):
            base_key = base_key[len(remove_prefix):]
        if base_key not in base_sd:
            raise KeyError(f"Base weight not found for module '{mod}' (looked for '{base_key}')")
        base_sd[base_key] = base_sd[base_key] + delta

    # 5) Save one big safetensors file (only create a directory if the output path has one)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_file(base_sd, output_path)
    print(f"✅ Merged safetensors written to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_transformer_dir", required=True,
        help="e.g. path/to/HiDream-I1-Full/transformer",
    )
    parser.add_argument(
        "--lora_path", required=True,
        help="Path to your LoRA/DoRA safetensors",
    )
    parser.add_argument(
        "--output_path", required=True,
        help="Path to output diffusion_pytorch_model.safetensors",
    )
    parser.add_argument(
        "--lora_scale", type=float, default=1.0,
        help="Global LoRA/DoRA scale",
    )
    args = parser.parse_args()

    manual_merge(
        base_transformer_dir=args.base_transformer_dir,
        lora_path=args.lora_path,
        output_path=args.output_path,
        lora_scale=args.lora_scale,
    )