Skip to content

Instantly share code, notes, and snippets.

@takuma104
Created May 11, 2023 14:43
Show Gist options
  • Save takuma104/4d37c583e62b04dc250541bae6291f93 to your computer and use it in GitHub Desktop.
Save takuma104/4d37c583e62b04dc250541bae6291f93 to your computer and use it in GitHub Desktop.
import torch
import sys
from safetensors.torch import load_file
from diffusers import StableDiffusionPipeline
# Prefix carried by the UNet entries of the LoRA file.
_UNET_PREFIX = 'lora_unet_'

# The blanket '_' -> '.' substitution below also splits identifiers that
# must keep their underscores on the diffusers side; these pairs put them
# back. They are applied in the order listed.
_UNET_NAME_FIXUPS = (
    ('down.blocks', 'down_blocks'),
    ('mid.block', 'mid_block'),
    ('up.blocks', 'up_blocks'),
    ('transformer.blocks', 'transformer_blocks'),
    ('to.q.lora', 'to_q_lora'),
    ('to.k.lora', 'to_k_lora'),
    ('to.v.lora', 'to_v_lora'),
    ('to.out.0.lora', 'to_out_lora'),
)


def _to_diffusers_name(key):
    """Rewrite one ``lora_unet_*`` down-weight key into its target name.

    Args:
        key: A state-dict key that starts with ``lora_unet_``.

    Returns:
        The rewritten name with ``.processor`` inserted after
        ``attn1``/``attn2``, or ``None`` when the key is not an
        ``attn1``/``attn2`` entry inside a ``transformer_blocks`` module.
    """
    # Strip only the leading prefix, then turn every '_' into '.'.
    name = key[len(_UNET_PREFIX):].replace('_', '.')
    for dotted, underscored in _UNET_NAME_FIXUPS:
        name = name.replace(dotted, underscored)

    if 'transformer_blocks' not in name:
        return None
    if 'attn1' not in name and 'attn2' not in name:
        return None

    name = name.replace('attn1', 'attn1.processor')
    return name.replace('attn2', 'attn2.processor')


def convert_lora_state_dict(state_dict):
    """Build the renamed state dict that is handed to ``load_attn_procs``.

    Only keys containing ``lora_down`` whose module name starts with
    ``lora_unet_`` are considered, and of those only the ``attn1``/``attn2``
    entries inside transformer blocks are emitted. Each emitted down weight
    is paired with the matching ``<module>.lora_up.weight`` entry. Every
    other entry of ``state_dict`` is ignored.

    NOTE(review): the listing this was restored from had lost its
    indentation; the two output assignments are assumed to belong inside the
    innermost ``attn1``/``attn2`` check -- confirm against the original.

    Args:
        state_dict: Mapping from key to weight. Values are copied through
            untouched and never inspected.

    Returns:
        A new dict keyed by the rewritten names.

    Raises:
        KeyError: If a converted down weight has no ``.lora_up.weight``
            partner in ``state_dict``.
    """
    converted = {}
    for key, value in state_dict.items():
        if 'lora_down' not in key:
            continue

        # Module name is everything before the first '.'.
        lora_name = key.split('.')[0]
        if not lora_name.startswith(_UNET_PREFIX):
            continue

        down_name = _to_diffusers_name(key)
        if down_name is None:
            continue

        up_key = lora_name + '.lora_up.weight'
        if up_key not in state_dict:
            raise KeyError(
                'no %r entry to pair with %r' % (up_key, key)
            )

        converted[down_name] = value
        converted[down_name.replace('.down.', '.up.')] = state_dict[up_key]
    return converted


if __name__ == '__main__':
    # Script entry point: read the LoRA file, rename its attention weights,
    # and load them into the pipeline's UNet.
    state_dict = load_file('some_lora.safetensors')
    new_state_dict = convert_lora_state_dict(state_dict)

    pipe = StableDiffusionPipeline.from_pretrained('...').to('cuda')
    pipe.unet.load_attn_procs(new_state_dict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment