Created May 8, 2023 15:49
import math

import safetensors.torch
import torch
from diffusers import DiffusionPipeline

"""
Kohya's LoRA format Loader for Diffusers

Usage:
```py
# A typical Diffusers setup
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained('...',
                                                torch_dtype=torch.float16).to('cuda')

# Import this module
import kohya_lora_loader

# Install the LoRA hook. This adds apply_lora and remove_lora methods to the pipe.
kohya_lora_loader.install_lora_hook(pipe)

# Load the 'lora1.safetensors' file and apply it
lora1 = pipe.apply_lora('lora1.safetensors', 1.0)

# You can change the alpha (strength) afterwards
lora1.alpha = 0.5

# Load the 'lora2.safetensors' file and apply it
lora2 = pipe.apply_lora('lora2.safetensors', 1.0)

# Generate an image with lora1 and lora2 applied
pipe(...).images[0]

# Remove lora2
pipe.remove_lora(lora2)

# Generate an image with only lora1 applied
pipe(...).images[0]

# Uninstall the LoRA hook
kohya_lora_loader.uninstall_lora_hook(pipe)

# Generate an image with no LoRA applied
pipe(...).images[0]
```
"""

# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if isinstance(alpha, torch.Tensor):
                alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
            self.register_buffer("alpha", torch.tensor(alpha))  # treatable as a constant

        # same as Microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        return self.multiplier * scale * self.lora_up(self.lora_down(x))

class LoRAModuleContainer(torch.nn.Module):
    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Create a LoRAModule for each entry in the state_dict
        for key, value in state_dict.items():
            if "lora_down" in key:
                lora_name = key.split(".")[0]
                lora_dim = value.size()[0]
                lora_name_alpha = key.split(".")[0] + ".alpha"
                alpha = None
                if lora_name_alpha in state_dict:
                    alpha = state_dict[lora_name_alpha].item()

                hook = hooks[lora_name]
                lora_module = LoRAModule(
                    hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
                )
                self.register_module(lora_name, lora_module)

        # Load the whole set of LoRA weights
        self.load_state_dict(state_dict)

        # Register each LoRAModule with its LoRAHook
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.append_lora(module)

    @property
    def alpha(self):
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        for name, module in self.named_modules():
            if module.__class__.__name__ == "LoRAModule":
                hook = hooks[name]
                hook.remove_lora(module)
                del module

class LoRAHook(torch.nn.Module):
    """
    Replaces the forward method of the original Linear/Conv2d module
    instead of replacing the module itself.
    """

    def __init__(self):
        super().__init__()
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x):
        if len(self.lora_modules) == 0:
            return self.orig_forward(x)
        # Sum the contributions of all applied LoRAs and add them to the original output
        lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
        return self.orig_forward(x) + lora

class LoRAHookInjector(object):
    def __init__(self):
        super().__init__()
        self.hooks = {}
        self.device = None
        self.dtype = None

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        target_modules = []
        for name, module in root_module.named_modules():
            if (
                module.__class__.__name__ in target_replace_modules
                and "transformer_blocks" not in name
            ):  # to adapt to the latest diffusers
                for child_name, child_module in module.named_modules():
                    is_linear = child_module.__class__.__name__ == "Linear"
                    is_conv2d = child_module.__class__.__name__ == "Conv2d"
                    if is_linear or is_conv2d:
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install a LoRAHook on every target module of the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall the LoRAHooks from the pipe."""
        for k, v in self.hooks.items():
            v.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0):
        """Load LoRA weights from a file and apply them to the pipe."""
        assert len(self.hooks) != 0
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, self.dtype)
        return container

    def remove_lora(self, container):
        """Remove an individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)

def install_lora_hook(pipe: DiffusionPipeline):
    """Install LoRAHook to the pipe."""
    assert not hasattr(pipe, "lora_injector")
    assert not hasattr(pipe, "apply_lora")
    assert not hasattr(pipe, "remove_lora")
    injector = LoRAHookInjector()
    injector.install_hooks(pipe)
    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora


def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Uninstall LoRAHook from the pipe."""
    pipe.lora_injector.uninstall_hooks()
    del pipe.lora_injector
    del pipe.apply_lora
    del pipe.remove_lora
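For readers trying the hook out, here is a small sanity-check sketch. The checkpoint name `runwayml/stable-diffusion-v1-5` is only an example (any SD 1.x pipeline should work), and the exact number of hooked modules depends on the diffusers version:

```py
# Sanity-check sketch: install the hook and inspect which modules were wrapped.
import torch
from diffusers import StableDiffusionPipeline
import kohya_lora_loader

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
kohya_lora_loader.install_lora_hook(pipe)

# Every hooked Linear/Conv2d in the text encoder and UNet gets an entry,
# keyed by its Kohya-style name (lora_te_* / lora_unet_*).
hooks = pipe.lora_injector.hooks
print(len(hooks))
print(sorted(hooks)[:3])
```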
I found out why it didn't work: it doesn't take the model's dtype into account, and the LoRA weights stay in float32 from what I've seen so far.
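If you hit that, a minimal workaround sketch (continuing from a pipe with the hook installed, as in the usage example above; `lora1.safetensors` is a placeholder filename) is to cast the returned container to the pipeline's device and dtype yourself:

```py
# Workaround sketch: force the loaded LoRA weights onto the pipe's device/dtype.
# apply_lora already calls container.to(pipe.device, pipe.unet.dtype), but an
# explicit cast rules out a stale float32 copy on an older revision of this gist.
lora = pipe.apply_lora("lora1.safetensors", 1.0)
lora.to(pipe.device, pipe.unet.dtype)  # LoRAModuleContainer is an nn.Module, so .to() works
```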
@takuma104 Can you add support for SDXL models?
@takuma104 do you think you'll be able to do Lycoris support?
For diffusers after #4147, consider adding the following module checks if you still want to use this hook (a consolidated sketch follows the list below):
at line 197
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ in ["Linear", "LoRACompatibleLinear"]
is_conv2d = child_module.__class__.__name__ in ["Conv2d", "LoRACompatibleConv"]
and at lines 60 & 69:
if org_module.__class__.__name__ in ["Conv2d", "LoRACompatibleConv"]:
also, add a default scale in LoRAHook.forward at line 176:
def forward(self, x, scale=1.0):
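Putting those three suggestions together, an untested sketch of the patched fragments might look like this (assuming, as the comment implies, that `LoRACompatibleLinear` / `LoRACompatibleConv` are the wrapper classes newer diffusers uses for these layers; the extra `scale` argument is simply accepted and ignored):

```py
# Inside LoRAHookInjector._get_target_modules: also match the LoRACompatible wrappers.
for child_name, child_module in module.named_modules():
    is_linear = child_module.__class__.__name__ in ["Linear", "LoRACompatibleLinear"]
    is_conv2d = child_module.__class__.__name__ in ["Conv2d", "LoRACompatibleConv"]
    if is_linear or is_conv2d:
        lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
        target_modules.append((lora_name, child_module))

# Inside LoRAModule.__init__: treat LoRACompatibleConv like Conv2d
# (in both places that currently check for "Conv2d").
if org_module.__class__.__name__ in ["Conv2d", "LoRACompatibleConv"]:
    in_dim = org_module.in_channels
    out_dim = org_module.out_channels

# LoRAHook.forward: accept the scale keyword that newer diffusers passes through.
def forward(self, x, scale=1.0):
    if len(self.lora_modules) == 0:
        return self.orig_forward(x)
    lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
    return self.orig_forward(x) + lora
```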
Hey, man. I have tried your code with a safetensors file downloaded from civitai. I got this error:
Traceback (most recent call last):
  File "/mecomics-api/tester.py", line 36, in <module>
    image = pipe(
  File "/usr/local/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mecomics-api/multiDiffusion.py", line 1064, in __call__
    noise_pred = self.unet(
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexblattnershalom/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 724, in forward
    sample, res_samples = downsample_block(
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexblattnershalom/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 868, in forward
    hidden_states = attn(
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexblattnershalom/.local/lib/python3.8/site-packages/diffusers/models/transformer_2d.py", line 251, in forward
    hidden_states = self.proj_in(hidden_states)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mecomics-api/kohya_lora_loader.py", line 179, in forward
    lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
  File "/mecomics-api/kohya_lora_loader.py", line 179, in <listcomp>
    lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mecomics-api/kohya_lora_loader.py", line 98, in forward
    return self.multiplier * scale * self.lora_up(self.lora_down(x))
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
this is how I loaded the ckpt:
@takuma104