@tori29umai0123
Last active November 7, 2025 07:08
import argparse
import logging
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
QWEN_IMAGE_KEYS = [
"time_text_embed.timestep_embedder.linear_1",
"time_text_embed.timestep_embedder.linear_2",
"txt_norm",
"img_in",
"txt_in",
"transformer_blocks.*.img_mod.1",
"transformer_blocks.*.attn.norm_q",
"transformer_blocks.*.attn.norm_k",
"transformer_blocks.*.attn.to_q",
"transformer_blocks.*.attn.to_k",
"transformer_blocks.*.attn.to_v",
"transformer_blocks.*.attn.add_k_proj",
"transformer_blocks.*.attn.add_v_proj",
"transformer_blocks.*.attn.add_q_proj",
"transformer_blocks.*.attn.to_out.0",
"transformer_blocks.*.attn.to_add_out",
"transformer_blocks.*.attn.norm_added_q",
"transformer_blocks.*.attn.norm_added_k",
"transformer_blocks.*.img_mlp.net.0.proj",
"transformer_blocks.*.img_mlp.net.2",
"transformer_blocks.*.txt_mod.1",
"transformer_blocks.*.txt_mlp.net.0.proj",
"transformer_blocks.*.txt_mlp.net.2",
"norm_out.linear",
"proj_out",
]
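
# Note: "*" in the keys above is a block-index wildcard; convert_to_diffusers
# below expands it to transformer block indices 0-99 when building its lookup table.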
# ------------------------------------------------------------------------------
# Diffusers → Kohya (default)
# ------------------------------------------------------------------------------
def convert_from_diffusers(prefix, weights_sd):
    new_weights_sd = {}
    lora_dims = {}
    for key, weight in weights_sd.items():
        # Flexibly recognize any of the supported prefixes
        if key.startswith("diffusion_model."):
            key_body = key[len("diffusion_model.") :]
        elif key.startswith("transformer."):
            key_body = key[len("transformer.") :]
        elif key.startswith("unet."):
            key_body = key[len("unet.") :]
        else:
            logger.warning(f"Unexpected key: {key}")
            continue

        new_key = f"{prefix}{key_body}".replace(".", "_")
        new_key = (
            new_key.replace("_lora_A_", ".lora_down.")
            .replace("_lora_B_", ".lora_up.")
            .replace("_lora_down_", ".lora_down.")
            .replace("_lora_up_", ".lora_up.")
        )
        if new_key.endswith("_alpha"):
            new_key = new_key.replace("_alpha", ".alpha")

        new_weights_sd[new_key] = weight
        if "lora_down" in new_key:
            lora_name = new_key.split(".")[0]
            lora_dims[lora_name] = weight.shape[0]

    # Add alpha entries (defaulting to the rank) for modules that are missing one
    for name, dim in lora_dims.items():
        alpha_key = f"{name}.alpha"
        if alpha_key not in new_weights_sd:
            new_weights_sd[alpha_key] = torch.tensor(dim)

    return new_weights_sd
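
# A minimal sketch of the renaming above (key names chosen for illustration):
# a Diffusers/PEFT key such as
#     "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"
# has its prefix stripped, dots flattened to underscores, and the lora_A /
# lora_B markers rewritten to Kohya's lora_down / lora_up, producing
#     "lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight".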
# ------------------------------------------------------------------------------
# Kohya (default) → Diffusers
# ------------------------------------------------------------------------------
def convert_to_diffusers(prefix, diffusers_prefix, adapter_name, weights_sd):
    diffusers_prefix_with_dot = f"{diffusers_prefix}." if diffusers_prefix else ""
    adapter_name = adapter_name or "default"

    # Build a lookup from Kohya LoRA module names back to Diffusers module paths
    lora_name_to_module = {}
    for key in QWEN_IMAGE_KEYS:
        base = key.replace(".", "_")
        if "*" not in key:
            lora_name_to_module[prefix + base] = key
        else:
            for i in range(100):
                lora_name_to_module[prefix + base.replace("*", str(i))] = key.replace("*", str(i))

    # Collect alpha values per LoRA module
    lora_alphas = {
        key.split(".")[0]: weight
        for key, weight in weights_sd.items()
        if key.startswith(prefix) and "alpha" in key
    }

    new_sd = {}
    for key, weight in weights_sd.items():
        if not key.startswith(prefix) or "alpha" in key:
            continue

        lora_name = key.split(".", 1)[0]
        module_name = lora_name_to_module.get(
            lora_name,
            lora_name[len(prefix) :].replace("_", "."),
        )

        if "lora_down" in key:
            new_key = f"{diffusers_prefix_with_dot}{module_name}.lora_A.{adapter_name}.weight"
            dim = weight.shape[0]
        elif "lora_up" in key:
            new_key = f"{diffusers_prefix_with_dot}{module_name}.lora_B.{adapter_name}.weight"
            dim = weight.shape[1]
        else:
            continue

        # Alpha scaling: fold the alpha factor into the weights
        if lora_name in lora_alphas:
            scale = (lora_alphas[lora_name] / dim).sqrt()
            weight = weight * scale

        new_sd[new_key] = weight

    return new_sd
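
# A minimal sketch of the reverse mapping (names chosen for illustration):
#     "lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight"
# is looked up in lora_name_to_module and, with the default prefix and adapter
# name, becomes
#     "diffusion_model.transformer_blocks.0.attn.to_q.lora_A.default.weight".
# When an alpha is stored, sqrt(alpha / rank) is multiplied into both lora_A
# and lora_B so that their product carries the full alpha / rank scale.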
# ------------------------------------------------------------------------------
# Main processing
# ------------------------------------------------------------------------------
def convert(input_file, output_file, target_format, diffusers_prefix, adapter_name):
logger.info(f"Loading {input_file}")
weights_sd = load_file(input_file)
with safe_open(input_file, framework="pt") as f:
metadata = dict(f.metadata() or {})
prefix = "lora_unet_"
if target_format == "kohya":
new_weights = convert_from_diffusers(prefix, weights_sd)
elif target_format == "diffusers":
new_weights = convert_to_diffusers(prefix, diffusers_prefix, adapter_name, weights_sd)
else:
raise ValueError(f"Invalid target format: {target_format}")
logger.info(f"Saving {output_file}")
save_file(new_weights, output_file, metadata)
logger.info("✅ Conversion complete.")
def parse_args():
    p = argparse.ArgumentParser(description="Convert LoRA between Kohya and Diffusers format")
    p.add_argument("--input", required=True, help="Input .safetensors file")
    p.add_argument("--output", required=True, help="Output .safetensors file")
    p.add_argument("--target", required=True, choices=["diffusers", "kohya"], help="Conversion direction")
    p.add_argument("--diffusers_prefix", default="diffusion_model", help="Prefix for Diffusers (default: diffusion_model)")
    p.add_argument("--adapter_name", default="default", help="Adapter name for Diffusers (default: default)")
    return p.parse_args()
def main():
    args = parse_args()
    convert(args.input, args.output, args.target, args.diffusers_prefix, args.adapter_name)
if __name__ == "__main__":
    main()
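
# Example invocations (the script file name is hypothetical):
#   python convert_lora_qwen_image.py --input lora_diffusers.safetensors \
#       --output lora_kohya.safetensors --target kohya
#   python convert_lora_qwen_image.py --input lora_kohya.safetensors \
#       --output lora_diffusers.safetensors --target diffusers \
#       --diffusers_prefix diffusion_model --adapter_name default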