Created November 8, 2023 01:19
Save arcanite24/63e72dc01efe2fc94b42771e8a00ddb7 to your computer and use it in GitHub Desktop.
Convert Diffusers LoRA safetensors to the A1111 (AUTOMATIC1111 WebUI) format
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from pathlib import Path | |
from diffusers import DiffusionPipeline | |
import torch | |
from safetensors.torch import save_file | |
# First, set up the argument parser | |
parser = argparse.ArgumentParser(description="Process some integers.") | |
parser.add_argument("input_file", type=str, help="The input safe-tensors file") | |
parser.add_argument("output_file", type=str, help="The output safe-tensors file") | |
# Parse the arguments | |
args = parser.parse_args() | |
# Then, use the arguments in the script | |
input_file = args.input_file | |
output_file = args.output_file | |
pipe = DiffusionPipeline.from_pretrained( | |
"segmind/SSD-1B", | |
torch_dtype=torch.float16, | |
) | |
state_dict, network_alphas = pipe.lora_state_dict( | |
Path(input_file), local_files_only=True | |
) | |
LORA_CLIP_MAP = { | |
"mlp.fc1": "mlp_fc1", | |
"mlp.fc2": "mlp_fc2", | |
"self_attn.k_proj": "self_attn_k_proj", | |
"self_attn.q_proj": "self_attn_q_proj", | |
"self_attn.v_proj": "self_attn_v_proj", | |
"self_attn.out_proj": "self_attn_out_proj", | |
"lora_linear_layer.down": "lora_down", | |
"lora_linear_layer.up": "lora_up", | |
} | |
LORA_UNET_MAP = { | |
"processor.to_q_lora.down": "to_q.lora_down", | |
"processor.to_q_lora.up": "to_q.lora_up", | |
"processor.to_k_lora.down": "to_k.lora_down", | |
"processor.to_k_lora.up": "to_k.lora_up", | |
"processor.to_v_lora.down": "to_v.lora_down", | |
"processor.to_v_lora.up": "to_v.lora_up", | |
"processor.to_out_lora.down": "to_out_0.lora_down", | |
"processor.to_out_lora.up": "to_out_0.lora_up", | |
"processor.to_q.alpha": "to_q.alpha", | |
"processor.to_k.alpha": "to_k.alpha", | |
"processor.to_v.alpha": "to_v.alpha", | |
} | |
webui_lora_state_dict = {} | |
for k, v in state_dict.items(): | |
is_text_encoder = False | |
prefix = k.split(".")[0] | |
if prefix == "text_encoder": | |
k = k.replace("text_encoder", "lora_te1") | |
is_text_encoder = True | |
elif prefix == "text_encoder_2": | |
k = k.replace("text_encoder_2", "lora_te2") | |
is_text_encoder = True | |
elif prefix == "unet": | |
k = k.replace("unet", "lora_unet") | |
if is_text_encoder: | |
for map_k, map_v in LORA_CLIP_MAP.items(): | |
k = k.replace(map_k, map_v) | |
else: | |
for map_k, map_v in LORA_UNET_MAP.items(): | |
k = k.replace(map_k, map_v) | |
keep_dots = 0 | |
if k.endswith(".alpha"): | |
keep_dots = 1 | |
elif k.endswith(".weight"): | |
keep_dots = 2 | |
parts = k.split(".") | |
k = "_".join(parts[:-keep_dots]) + "." + ".".join(parts[-keep_dots:]) | |
webui_lora_state_dict[k] = v | |
save_file(webui_lora_state_dict, output_file) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.