
@arcanite24
Created November 8, 2023 01:19
Convert a Diffusers LoRA safetensors file to AUTOMATIC1111 (webui) format
import argparse
from pathlib import Path
from diffusers import DiffusionPipeline
import torch
from safetensors.torch import save_file
# First, set up the argument parser
parser = argparse.ArgumentParser(
    description="Convert a Diffusers LoRA safetensors file to AUTOMATIC1111 (webui) format."
)
parser.add_argument("input_file", type=str, help="The input safetensors file")
parser.add_argument("output_file", type=str, help="The output safetensors file")
# Parse the arguments
args = parser.parse_args()
# Then, use the arguments in the script
input_file = args.input_file
output_file = args.output_file
# Load the base pipeline; it is only used here for its lora_state_dict()
# helper, which parses the Diffusers-format LoRA file.
pipe = DiffusionPipeline.from_pretrained(
    "segmind/SSD-1B",
    torch_dtype=torch.float16,
)
# lora_state_dict() returns the LoRA tensors keyed by Diffusers module path,
# along with the network alphas (unused below).
state_dict, network_alphas = pipe.lora_state_dict(
    Path(input_file), local_files_only=True
)
# Text-encoder key fragments: Diffusers layout -> webui (Kohya) layout
LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
    "lora_linear_layer.down": "lora_down",
    "lora_linear_layer.up": "lora_up",
}
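# For illustration (hypothetical key, not taken from a real checkpoint):
# the map above plus the underscore-join at the end of the loop below turn
#   "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight"
# into
#   "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"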
# UNet attention-processor key fragments: Diffusers layout -> webui (Kohya) layout
LORA_UNET_MAP = {
    "processor.to_q_lora.down": "to_q.lora_down",
    "processor.to_q_lora.up": "to_q.lora_up",
    "processor.to_k_lora.down": "to_k.lora_down",
    "processor.to_k_lora.up": "to_k.lora_up",
    "processor.to_v_lora.down": "to_v.lora_down",
    "processor.to_v_lora.up": "to_v.lora_up",
    "processor.to_out_lora.down": "to_out_0.lora_down",
    "processor.to_out_lora.up": "to_out_0.lora_up",
    "processor.to_q.alpha": "to_q.alpha",
    "processor.to_k.alpha": "to_k.alpha",
    "processor.to_v.alpha": "to_v.alpha",
}
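# Likewise, a hypothetical UNet key such as
#   "unet.down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"
# ends up as
#   "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"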
webui_lora_state_dict = {}
for k, v in state_dict.items():
    # Swap the top-level prefix for the webui naming convention.
    is_text_encoder = False
    prefix = k.split(".")[0]
    if prefix == "text_encoder":
        k = k.replace("text_encoder", "lora_te1")
        is_text_encoder = True
    elif prefix == "text_encoder_2":
        k = k.replace("text_encoder_2", "lora_te2")
        is_text_encoder = True
    elif prefix == "unet":
        k = k.replace("unet", "lora_unet")

    # Rewrite the module-path fragments using the appropriate map.
    if is_text_encoder:
        for map_k, map_v in LORA_CLIP_MAP.items():
            k = k.replace(map_k, map_v)
    else:
        for map_k, map_v in LORA_UNET_MAP.items():
            k = k.replace(map_k, map_v)

    # The webui format joins the module path with underscores, keeping only
    # the trailing ".alpha" or ".lora_down.weight"/".lora_up.weight" dotted.
    keep_dots = 0
    if k.endswith(".alpha"):
        keep_dots = 1
    elif k.endswith(".weight"):
        keep_dots = 2

    if keep_dots:  # guard: with keep_dots == 0, parts[:-0] would slice to an empty list
        parts = k.split(".")
        k = "_".join(parts[:-keep_dots]) + "." + ".".join(parts[-keep_dots:])

    webui_lora_state_dict[k] = v
# Write the converted tensors to the output safetensors file.
save_file(webui_lora_state_dict, output_file)
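Run the script with the input and output paths as positional arguments; the script filename below is a placeholder, not part of the gist:

python convert_diffusers_lora_to_webui.py input_lora.safetensors output_webui.safetensors

As a quick sanity check, a minimal sketch (appended after the save above) that reopens the converted file and prints a few of the renamed keys:

from safetensors.torch import load_file

converted = load_file(output_file)
print(f"{len(converted)} tensors written")
for name in list(converted)[:5]:  # show a handful of converted keys
    print(name, tuple(converted[name].shape))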