Last active
November 7, 2025 07:08
-
-
Save tori29umai0123/2fb47a0eb3d9e1e55002b4beb5292021 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import logging | |
| import torch | |
| from safetensors import safe_open | |
| from safetensors.torch import load_file, save_file | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sub-module paths repeated inside every transformer block, in table order.
_QWEN_BLOCK_SUBMODULES = (
    "img_mod.1",
    "attn.norm_q",
    "attn.norm_k",
    "attn.to_q",
    "attn.to_k",
    "attn.to_v",
    "attn.add_k_proj",
    "attn.add_v_proj",
    "attn.add_q_proj",
    "attn.to_out.0",
    "attn.to_add_out",
    "attn.norm_added_q",
    "attn.norm_added_k",
    "img_mlp.net.0.proj",
    "img_mlp.net.2",
    "txt_mod.1",
    "txt_mlp.net.0.proj",
    "txt_mlp.net.2",
)

# Dotted module paths used to undo Kohya's "." -> "_" flattening when
# converting back to Diffusers naming. "*" stands for a transformer-block index.
QWEN_IMAGE_KEYS = [
    "time_text_embed.timestep_embedder.linear_1",
    "time_text_embed.timestep_embedder.linear_2",
    "txt_norm",
    "img_in",
    "txt_in",
    *(f"transformer_blocks.*.{sub}" for sub in _QWEN_BLOCK_SUBMODULES),
    "norm_out.linear",
    "proj_out",
]
# ------------------------------------------------------------------------------
# Diffusers → Kohya (default)
# ------------------------------------------------------------------------------
# Diffusers/PEFT LoRA sub-module tag -> Kohya tensor tag.
_LORA_TAG_TO_KOHYA = {
    "lora_A": "lora_down",
    "lora_down": "lora_down",
    "lora_B": "lora_up",
    "lora_up": "lora_up",
}

# Root prefixes accepted on Diffusers-style keys; stripped before conversion.
_DIFFUSERS_KEY_PREFIXES = ("diffusion_model.", "transformer.", "unet.")


def convert_from_diffusers(prefix, weights_sd):
    """Convert a Diffusers/PEFT-style LoRA state dict to Kohya naming.

    Input keys look like ``<root>.<module.path>.lora_A[.<adapter>].<param>``
    where ``<root>`` is ``diffusion_model``, ``transformer`` or ``unet``.
    They become ``<prefix><module_path_with_underscores>.lora_down.<param>``
    (``lora_B`` -> ``lora_up``). An adapter-name segment between the LoRA tag
    and the parameter name (e.g. ``default``, as written by
    ``convert_to_diffusers``) is dropped. ``<module.path>.alpha`` keys become
    ``<prefix><module_path>.alpha``.

    Args:
        prefix: Kohya name prefix, e.g. ``"lora_unet_"``.
        weights_sd: mapping of key -> tensor.

    Returns:
        A new dict with Kohya-style keys. Every module that has a
        ``lora_down`` tensor but no ``.alpha`` entry gets
        ``alpha = torch.tensor(rank)`` (i.e. a scale of 1.0). Keys with an
        unrecognised root prefix are skipped with a warning.
    """
    new_weights_sd = {}
    lora_dims = {}
    for key, weight in weights_sd.items():
        key_body = None
        for root in _DIFFUSERS_KEY_PREFIXES:
            if key.startswith(root):
                key_body = key[len(root):]
                break
        if key_body is None:
            logger.warning(f"Unexpected key: {key}")
            continue

        parts = key_body.split(".")
        # The LoRA tag needs a module path before it and a parameter name after it.
        tag_idx = next(
            (i for i in range(1, len(parts) - 1) if parts[i] in _LORA_TAG_TO_KOHYA),
            None,
        )
        if tag_idx is not None:
            lora_name = prefix + "_".join(parts[:tag_idx])
            kohya_tag = _LORA_TAG_TO_KOHYA[parts[tag_idx]]
            # parts[-1] is the parameter name; any adapter segment in between is dropped.
            new_key = f"{lora_name}.{kohya_tag}.{parts[-1]}"
            if kohya_tag == "lora_down":
                # lora_down has shape (rank, in_features, ...): dim 0 is the rank.
                lora_dims[lora_name] = weight.shape[0]
        elif len(parts) > 1 and parts[-1] == "alpha":
            new_key = prefix + "_".join(parts[:-1]) + ".alpha"
        else:
            new_key = prefix + "_".join(parts)
            logger.warning(f"No LoRA tag found in key {key}; copied as {new_key}")
        new_weights_sd[new_key] = weight

    # Kohya loaders expect an alpha per module; alpha == rank keeps the scale at 1.0.
    for lora_name, dim in lora_dims.items():
        new_weights_sd.setdefault(f"{lora_name}.alpha", torch.tensor(dim))
    return new_weights_sd
# ------------------------------------------------------------------------------
# Kohya (default) → Diffusers
# ------------------------------------------------------------------------------
def _resolve_module_name(lora_body, module_keys):
    """Map a flattened Kohya module name back to its dotted module path.

    ``lora_body`` is the Kohya LoRA name with its prefix removed
    (e.g. ``transformer_blocks_12_attn_to_out_0``). ``module_keys`` holds
    dotted paths in which ``*`` stands for a block index. Returns the
    matching dotted path (e.g. ``transformer_blocks.12.attn.to_out.0``), or
    ``None`` when no entry matches.
    """
    for key in module_keys:
        flat = key.replace(".", "_")
        if "*" not in flat:
            if lora_body == flat:
                return key
            continue
        head, tail = flat.split("*", 1)
        if lora_body.startswith(head) and lora_body.endswith(tail):
            index = lora_body[len(head):len(lora_body) - len(tail)]
            # Require a pure numeric index so e.g. "_attn_norm_k" can't swallow "_attn_norm_added_k".
            if index.isdigit():
                return key.replace("*", index)
    return None


def convert_to_diffusers(prefix, diffusers_prefix, adapter_name, weights_sd, module_keys=None):
    """Convert a Kohya-style LoRA state dict to Diffusers/PEFT naming.

    ``<prefix><flat_module>.lora_down.weight`` / ``.lora_up.weight`` become
    ``[<diffusers_prefix>.]<module.path>.lora_A.<adapter_name>.weight`` /
    ``...lora_B...``. When a ``<prefix><flat_module>.alpha`` entry exists,
    both factors are multiplied by ``sqrt(alpha / rank)``, so their product
    carries Kohya's ``alpha / rank`` scale. Alpha tensors are not emitted.

    Args:
        prefix: Kohya name prefix, e.g. ``"lora_unet_"``. Keys without it are dropped.
        diffusers_prefix: leading path for output keys; empty means no prefix.
        adapter_name: PEFT adapter segment; a falsy value means ``"default"``.
        weights_sd: mapping of key -> tensor in Kohya naming.
        module_keys: dotted module paths (``*`` = block index) used to undo
            Kohya's ``.`` -> ``_`` flattening. Defaults to ``QWEN_IMAGE_KEYS``.

    Returns:
        A new dict in Diffusers naming. Tensors other than
        ``lora_down.weight`` / ``lora_up.weight`` (bias, DoRA scale, LyCORIS
        factors, ...) are skipped with a warning.
    """
    if module_keys is None:
        module_keys = QWEN_IMAGE_KEYS
    diffusers_prefix_with_dot = f"{diffusers_prefix}." if diffusers_prefix else ""
    adapter_name = adapter_name or "default"

    # Collect alpha per LoRA module name.
    lora_alphas = {
        key.split(".", 1)[0]: weight
        for key, weight in weights_sd.items()
        if key.startswith(prefix) and key.endswith(".alpha")
    }

    module_names = {}  # cache: Kohya lora_name -> dotted module path
    new_sd = {}
    for key, weight in weights_sd.items():
        if not key.startswith(prefix) or key.endswith(".alpha"):
            continue
        lora_name, _, tensor_name = key.partition(".")
        if tensor_name == "lora_down.weight":
            lora_tag, dim = "lora_A", weight.shape[0]  # (rank, in_features, ...)
        elif tensor_name == "lora_up.weight":
            lora_tag, dim = "lora_B", weight.shape[1]  # (out_features, rank, ...)
        else:
            logger.warning(f"Skipping unsupported key: {key}")
            continue

        module_name = module_names.get(lora_name)
        if module_name is None:
            body = lora_name[len(prefix):]
            module_name = _resolve_module_name(body, module_keys)
            if module_name is None:
                # Best-effort guess; wrong whenever the real module path contains "_".
                module_name = body.replace("_", ".")
                logger.warning(f"Unknown module {lora_name}; guessing {module_name}")
            module_names[lora_name] = module_name

        alpha = lora_alphas.get(lora_name)
        if alpha is not None:
            # Split alpha/rank evenly across A and B; computed in Python float
            # so a half-precision alpha tensor doesn't degrade the scale.
            weight = weight * (float(alpha) / dim) ** 0.5

        new_key = f"{diffusers_prefix_with_dot}{module_name}.{lora_tag}.{adapter_name}.weight"
        new_sd[new_key] = weight
    return new_sd
# ------------------------------------------------------------------------------
# Main conversion
# ------------------------------------------------------------------------------
def convert(input_file, output_file, target_format, diffusers_prefix, adapter_name):
    """Convert a LoRA ``.safetensors`` file between Kohya and Diffusers naming.

    The input file's safetensors metadata is copied unchanged to the output.

    Args:
        input_file: path of the LoRA file to read.
        output_file: path to write the converted file to.
        target_format: ``"kohya"`` (Diffusers -> Kohya) or ``"diffusers"``
            (Kohya -> Diffusers).
        diffusers_prefix: key prefix used when writing Diffusers naming.
        adapter_name: adapter segment used when writing Diffusers naming.

    Raises:
        ValueError: if ``target_format`` is not supported. This is checked
            before the input file is read.
    """
    # Validate first so a typo does not cost a full load of a large file.
    if target_format not in ("kohya", "diffusers"):
        raise ValueError(f"Invalid target format: {target_format}")

    logger.info(f"Loading {input_file}")
    weights_sd = load_file(input_file)
    with safe_open(input_file, framework="pt") as f:
        metadata = dict(f.metadata() or {})

    prefix = "lora_unet_"
    if target_format == "kohya":
        new_weights = convert_from_diffusers(prefix, weights_sd)
    else:
        new_weights = convert_to_diffusers(prefix, diffusers_prefix, adapter_name, weights_sd)

    if not new_weights:
        # Usually means the input is already in the target format, or uses unknown key names.
        logger.warning(
            f"No tensors were converted from {input_file}; the output file will be empty. "
            f"Check that the input is not already in '{target_format}' format."
        )

    logger.info(f"Saving {output_file}")
    save_file(new_weights, output_file, metadata)
    logger.info("✅ Conversion complete.")
def parse_args(argv=None):
    """Parse the converter's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        An ``argparse.Namespace`` with ``input``, ``output``, ``target``,
        ``diffusers_prefix`` and ``adapter_name``.
    """
    p = argparse.ArgumentParser(description="Convert LoRA between Kohya and Diffusers format")
    p.add_argument("--input", required=True, help="Input .safetensors file")
    p.add_argument("--output", required=True, help="Output .safetensors file")
    p.add_argument("--target", required=True, choices=["diffusers", "kohya"], help="Conversion direction")
    p.add_argument("--diffusers_prefix", default="diffusion_model", help="Prefix for Diffusers (default: diffusion_model)")
    p.add_argument("--adapter_name", default="default", help="Adapter name for Diffusers (default: default)")
    return p.parse_args(argv)
def main():
    """Command-line entry point: read the options and run one conversion."""
    options = parse_args()
    convert(
        input_file=options.input,
        output_file=options.output,
        target_format=options.target,
        diffusers_prefix=options.diffusers_prefix,
        adapter_name=options.adapter_name,
    )


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment