Lycoris Inference
import lycoris_inference
import lora
from diffusers import DiffusionPipeline
import torch
from safetensors import safe_open
from safetensors.torch import load_file
import time
import json

# load the SDXL pipeline
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")

# loha link    : https://civitai.com/models/111594/sd-xl09-loha-pearly-gates-concept
# lora link    : https://civitai.com/models/112904/arcane-style-lora-xl09
# lycoris link : https://civitai.com/models/108011/fcbodybuildingxl-10-for-sdxl
# lora_model = "loha.safetensors"
lora_model = "lycoris.safetensors"
# lora_model = "sdxl.safetensors"
lora_strength = 1

# pipe.load_lora_weights(".", weight_name=lora_model)

# read the safetensors metadata first, then the tensors themselves
weights_sd = safe_open(lora_model, framework="pt")
network_args = weights_sd.metadata()
print(network_args)
weights_sd = None
weights_sd = load_file(lora_model)

# determine the algorithm (lora / loha / lokr / ...) from the training metadata;
# initialized to None so a metadata miss cannot leave it undefined
algo = None
try:
    ss_network_args_dict = json.loads(network_args["ss_network_args"])
    if "algo" in ss_network_args_dict:
        algo = ss_network_args_dict["algo"]
except Exception as e:
    try:
        algo = network_args["ss_network_module"]
        if algo == "networks.lora":
            algo = "lora"
    except Exception as e:
        algo = "lora"
        print(e)
        print("Error: could not load ss_network_args")

if algo == "lora":
    # plain LoRA: diffusers can load it directly
    pipe.load_lora_weights(".", weight_name=lora_model)
    # network = lora
else:
    network = lycoris_inference
    # SDXL has two text encoders
    network, weights_sd = network.create_network_from_weights(
        multiplier=lora_strength,
        file="",
        vae=pipe.vae,
        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
        unet=pipe.unet,
        weights_sd=weights_sd,
        for_inference=True,
        algo=algo,
    )
    network.apply_to()
    info = network.load_state_dict(weights_sd, False)
    network.to("cuda", dtype=torch.float16)

# create an image
generator = torch.Generator("cuda").manual_seed(0)
prompt = "pearlygates, 1girl, solo, scenery, long hair, dress, cloudy sky, standing"
image = pipe(prompt=prompt, generator=generator).images[0]
image.save(time.strftime("%Y%m%d_%H%M%S") + ".png")
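
A minimal follow-up sketch, continuing the script above (assumptions: the LyCORIS branch ran, so network is in scope; restore() is defined in lycoris_inference below). Undoing the patch returns the pipeline to the base model:

with torch.no_grad():  # copying back onto parameters requires no-grad
    network.restore()  # reinstates the cached original forwards and weights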
lycoris_inference.py
import torch
from lycoris import kohya
from lycoris.modules import locon, loha, lokr

# exclude "Attention" modules from the UNet replacement targets
kohya.LycorisNetwork.UNET_TARGET_REPLACE_MODULE.remove("Attention")


class LokrModule(lokr.LokrModule):
    def make_weight(self):
        # reconstruct the full delta weight from the low-rank factors
        org_weight = self.org_module[0].weight.to(torch.float)
        up = self.lora_up.weight.to(
            device=org_weight.device, dtype=org_weight.dtype)
        down = self.lora_down.weight.to(
            device=org_weight.device, dtype=org_weight.dtype)
        if self.cp:
            # CP decomposition: contract mid with up and down
            mid = self.lora_mid.weight.to(
                device=org_weight.device, dtype=org_weight.dtype)
            up = up.reshape(up.size(0), up.size(1))
            down = down.reshape(down.size(0), down.size(1))
            weight = torch.einsum(
                "i j k l, i p, j r -> p r k l", mid, up, down)
        else:
            weight = up.reshape(
                up.size(0), -1) @ down.reshape(down.size(0), -1)
        return weight.reshape(org_weight.shape) * self.scale

    def merge_to(self, *args):
        org_weight = self.org_module[0].weight
        weight = self.make_weight() * self.multiplier
        merged_weight = org_weight + weight.to(org_weight.dtype)
        org_weight.copy_(merged_weight)


class LoConModule(locon.LoConModule):
    def make_weight(self):
        org_weight = self.org_module[0].weight.to(torch.float)
        up = self.lora_up.weight.to(
            device=org_weight.device, dtype=org_weight.dtype)
        down = self.lora_down.weight.to(
            device=org_weight.device, dtype=org_weight.dtype)
        if self.cp:
            mid = self.lora_mid.weight.to(
                device=org_weight.device, dtype=org_weight.dtype)
            up = up.reshape(up.size(0), up.size(1))
            down = down.reshape(down.size(0), down.size(1))
            weight = torch.einsum(
                "m n w h, i m, n j -> i j w h", mid, up, down)
        else:
            weight = up.reshape(
                up.size(0), -1) @ down.reshape(down.size(0), -1)
        return weight.reshape(org_weight.shape) * self.scale

    def merge_to(self):
        org_weight = self.org_module[0].weight
        weight = self.make_weight() * self.multiplier
        org_weight.copy_(org_weight + weight.to(org_weight.dtype))


class LohaModule(loha.LohaModule):
    def make_weight(self):
        org_weight = self.org_module[0].weight.to(torch.float)
        w1a = self.hada_w1_a.to(device=org_weight.device, dtype=org_weight.dtype)
        w1b = self.hada_w1_b.to(device=org_weight.device, dtype=org_weight.dtype)
        w2a = self.hada_w2_a.to(device=org_weight.device, dtype=org_weight.dtype)
        w2b = self.hada_w2_b.to(device=org_weight.device, dtype=org_weight.dtype)
        if self.cp:
            t1 = self.hada_t1.to(device=org_weight.device, dtype=org_weight.dtype)
            t2 = self.hada_t2.to(device=org_weight.device, dtype=org_weight.dtype)
            weight_1 = torch.einsum("i j k l, j r -> i r k l", t1, w1b)
            weight_1 = torch.einsum("i j k l, i r -> r j k l", weight_1, w1a)
            weight_2 = torch.einsum("i j k l, j r -> i r k l", t2, w2b)
            weight_2 = torch.einsum("i j k l, i r -> r j k l", weight_2, w2a)
        else:
            weight_1 = w1a @ w1b
            weight_2 = w2a @ w2b
        # LoHa delta is the Hadamard (element-wise) product of the two factors
        return (weight_1 * weight_2).reshape(org_weight.shape) * self.scale

    def merge_to(self):
        org_weight = self.org_module[0].weight
        weight = self.make_weight() * self.multiplier
        org_weight.copy_(org_weight + weight.to(org_weight.dtype))


def get_metadata(algo: str, weight):
    # scan the state dict for alpha/dim values and whether CP decomposition is used
    if algo == "lora":
        use_cp = False
        conv_alpha = None
        conv_lora_dim = None
        lora_alpha = None
        lora_dim = None
        for key, value in weight.items():
            if key.endswith("alpha"):
                base_key = key[:-6]  # strip the ".alpha" suffix

                def get_dim():
                    lora_up = weight[f"{base_key}.lora_up.weight"].size()[1]
                    lora_down = weight[f"{base_key}.lora_down.weight"].size()[0]
                    assert (
                        lora_up == lora_down
                    ), "lora_up and lora_down must be same size"
                    return lora_up

                if any(base_key.endswith(x) for x in ["conv", "conv1", "conv2"]):
                    conv_alpha = int(value)
                    conv_lora_dim = get_dim()
                else:
                    lora_alpha = int(value)
                    lora_dim = get_dim()
                if f"{base_key}.lora_mid.weight" in weight:
                    use_cp = True
        return conv_alpha, conv_lora_dim, lora_alpha, lora_dim, {"use_cp": use_cp}
    elif algo == "loha":
        use_cp = False
        conv_alpha = None
        conv_lora_dim = None
        lora_alpha = None
        lora_dim = None
        for key, value in weight.items():
            if key.endswith("alpha"):
                base_key = key[:-6]

                def get_dim():
                    hada_w1_b = weight[f"{base_key}.hada_w1_b"].size()[0]
                    hada_w2_b = weight[f"{base_key}.hada_w2_b"].size()[0]
                    assert (
                        hada_w1_b == hada_w2_b
                    ), "hada_w1_b and hada_w2_b must be same size"
                    return hada_w1_b

                if any(base_key.endswith(x) for x in ["conv", "conv1", "conv2"]):
                    conv_alpha = int(value)
                    conv_lora_dim = get_dim()
                else:
                    lora_alpha = int(value)
                    lora_dim = get_dim()
                if f"{base_key}.hada_t1" in weight and f"{base_key}.hada_t2" in weight:
                    use_cp = True
        return conv_alpha, conv_lora_dim, lora_alpha, lora_dim, {"use_cp": use_cp}
    elif algo == "lokr":
        use_cp = False
        conv_alpha = None
        conv_lora_dim = None
        lora_alpha = None
        lora_dim = None
        for key, value in weight.items():
            if key.endswith("alpha"):
                base_key = key[:-6]

                def get_dim():
                    # dimension is not recovered from LoKr factor shapes here
                    return None

                if any(base_key.endswith(x) for x in ["conv", "conv1", "conv2"]):
                    conv_alpha = int(value)
                    conv_lora_dim = get_dim()
                else:
                    lora_alpha = int(value)
                    lora_dim = get_dim()
                if f"{base_key}.lora_mid.weight" in weight:
                    use_cp = True
        # additional text-encoder layers carry usable alpha/dim values
        lora_layers = [
            "mlp_fc1", "mlp_fc2", "self_attn_k_proj",
            "self_attn_out_proj", "self_attn_q_proj", "self_attn_v_proj"
        ]
        for lora_layer in lora_layers:
            if f"lora_te2_text_model_encoder_layers_9_{lora_layer}.alpha" in weight:
                lora_alpha = int(
                    weight[f"lora_te2_text_model_encoder_layers_9_{lora_layer}.alpha"])
                lora_dim = weight[f"lora_te2_text_model_encoder_layers_9_{lora_layer}.lokr_w1"].size()[1]
        return conv_alpha, conv_lora_dim, lora_alpha, lora_dim, {"use_cp": use_cp}
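
# Illustrative example of the key layout get_metadata() expects, following the
# kohya-style naming the code above keys off (the module path here is hypothetical):
#   lora_unet_down_blocks_0_attentions_0_proj_in.alpha
#   lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight    # shape (out, rank)
#   lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight  # shape (rank, in)
# ".alpha" is six characters long, hence base_key = key[:-6] above.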
def create_network_from_weights(
    multiplier: float,
    file: str,
    vae,
    text_encoder,
    unet,
    algo=None,
    weights_sd: dict = None,
    **kwargs,
):
    apply_unet = None
    apply_te = None
    additional_kwargs = {}
    print(algo)
    for key in weights_sd.keys():
        if key.startswith("lora_unet"):
            apply_unet = True
        elif key.startswith("lora_te"):
            apply_te = True
        if algo is None:
            # fall back to inferring the algorithm from the key names
            if "lora_up" in key or "lora_down" in key:
                algo = "lora"
            elif "hada" in key:
                algo = "loha"
        if apply_unet is not None and apply_te is not None and algo is not None:
            break
    if algo is None:
        raise ValueError("Could not determine network module")
    (
        conv_alpha,
        conv_dim,
        lora_alpha,
        lora_dim,
        additional_kwargs,
    ) = get_metadata(algo, weights_sd)
    if lora_dim is None or lora_alpha is None:
        lora_dim = 0
        lora_alpha = 0
    if conv_dim is None or conv_alpha is None:
        conv_dim = 0
        conv_alpha = 0
    network_module = {
        "lora": LoConModule,
        "locon": LoConModule,
        "loha": LohaModule,
        # "ia3": IA3Module,
        "lokr": LokrModule,
        # "dylora": DyLoraModule,
        # "glora": GLoRAModule,
    }[algo]
    network = LycorisNetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=lora_dim,
        conv_lora_dim=int(conv_dim),
        alpha=lora_alpha,
        conv_alpha=conv_alpha,
        network_module=network_module,
        apply_unet=apply_unet,
        apply_te=apply_te,
        weights_sd=weights_sd,
        **additional_kwargs,
    )
    return network, weights_sd


class LycorisNetwork(kohya.LycorisNetwork):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.apply_unet = kwargs.get("apply_unet", True)
        self.apply_te = kwargs.get("apply_te", True)
        self.weights_sd = kwargs.get("weights_sd", None)
        if self.apply_unet:
            for lora in self.unet_loras:
                self.add_module(lora.lora_name, lora)
        if self.apply_te:
            for lora in self.text_encoder_loras:
                self.add_module(lora.lora_name, lora)
        # cache each original forward and weight so restore() can undo everything
        for lora in self.text_encoder_loras + self.unet_loras:
            org_module = lora.org_module[0]
            if not hasattr(org_module, "_lora_org_forward"):
                setattr(org_module, "_lora_org_forward", org_module.forward)
            if not hasattr(org_module, "_lora_org_weight"):
                setattr(org_module, "_lora_org_weight",
                        org_module.weight.clone().cpu())

    def apply_to(self):
        apply_text_encoder = self.apply_te
        apply_unet = self.apply_unet
        if self.weights_sd:
            weights_has_text_encoder = weights_has_unet = False
            for key in self.weights_sd.keys():
                if key.startswith(LycorisNetwork.LORA_PREFIX_TEXT_ENCODER):
                    weights_has_text_encoder = True
                elif key.startswith(LycorisNetwork.LORA_PREFIX_UNET):
                    weights_has_unet = True
            if apply_text_encoder is None:
                apply_text_encoder = weights_has_text_encoder
            else:
                assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} (weights and flag are inconsistent)"
            if apply_unet is None:
                apply_unet = weights_has_unet
            else:
                assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} (weights and flag are inconsistent)"
        else:
            assert apply_text_encoder is not None and apply_unet is not None, "internal error: flag not set"
        if apply_text_encoder:
            print("enable LyCORIS for text encoder")
        else:
            self.text_encoder_loras = []
        if apply_unet:
            print("enable LyCORIS for U-Net")
        else:
            self.unet_loras = []
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)
        if self.weights_sd:
            info = self.load_state_dict(self.weights_sd, False)
            print("weights are loaded")

    def merge_to(self):
        # permanently add each module's delta onto the original weight
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.merge_to()

    def restore(self, *args):
        # undo apply_to()/merge_to() from the cached forwards and weights
        for lora in self.text_encoder_loras + self.unet_loras:
            org_module = lora.org_module[0]
            if hasattr(org_module, "_lora_org_forward"):
                org_module.forward = org_module._lora_org_forward
                del org_module._lora_org_forward
            if hasattr(org_module, "_lora_org_weight"):
                org_module.weight.copy_(org_module._lora_org_weight)
                del org_module._lora_org_weight
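
As an alternative to hooking module forwards with apply_to(), the same network can bake its deltas directly into the base weights via merge_to(). A minimal sketch, assuming network and weights_sd came from create_network_from_weights() above:

import torch

network.load_state_dict(weights_sd, strict=False)
network.to("cuda", dtype=torch.float16)
with torch.no_grad():   # in-place copy_ on parameters needs no-grad
    network.merge_to()  # adds make_weight() * multiplier onto each original weight
# ...run the pipeline as usual, then undo the merge from the cached CPU copies:
with torch.no_grad():
    network.restore()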
requirements.txt
anyio==3.7.1
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-lru==2.0.3
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.0.0
blinker==1.4
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==2.1.1
cmake==3.25.0
comm==0.1.3
cryptography==3.4.8
dbus-python==1.2.18
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
diffusers==0.19.2
distro==1.7.0
einops==0.6.1
exceptiongroup==1.1.2
executing==1.2.0
fastjsonschema==2.17.1
filelock==3.9.0
fqdn==1.5.1
fsspec==2023.6.0
httplib2==0.20.2
huggingface-hub==0.16.4
idna==3.4
importlib-metadata==4.6.4
invisible-watermark==0.2.0
ipykernel==6.24.0
ipython==8.14.0
ipython-genutils==0.2.0
ipywidgets==8.0.7
isoduration==20.11.0
jedi==0.18.2
jeepney==0.7.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.18.0
jsonschema-specifications==2023.6.1
jupyter-archive==3.3.4
jupyter-contrib-core==0.4.2
jupyter-contrib-nbextensions==0.7.0
jupyter-events==0.6.3
jupyter-highlight-selected-word==0.2.0
jupyter-lsp==2.2.0
jupyter-nbextensions-configurator==0.6.3
jupyter_client==8.3.0
jupyter_core==5.3.1
jupyter_server==2.7.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.2
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.8
jupyterlab_server==2.23.0
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lit==15.0.7
losalina==1.0.0
lxml==4.9.3
lycoris-lora==1.8.0
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==3.0.1
more-itertools==8.10.0
mpmath==1.2.1
nbclassic==1.0.0
nbclient==0.8.0
nbconvert==7.6.0
nbformat==5.9.1
nest-asyncio==1.5.6
networkx==3.0
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.24.1
oauthlib==3.2.0
opencv-python==4.8.0.74
overrides==7.3.1
packaging==23.1
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
platformdirs==3.8.1
prometheus-client==0.17.0
prompt-toolkit==3.0.39
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.15.1
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==2.4.7
python-apt==2.4.0+ubuntu1
python-dateutil==2.8.2
python-json-logger==2.0.7
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.1.0
referencing==0.29.1
regex==2023.6.3
requests==2.28.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.8.10
safetensors==0.3.1
SecretStorage==3.3.1
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
sympy==1.11.1
terminado==0.17.1
timm==0.9.2
tinycss2==1.2.1
tokenizers==0.13.3
tomli==2.0.1
torch==2.0.1+cu118
torchaudio==2.0.2+cu118
torchvision==0.15.2+cu118
tornado==6.3.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.31.0
triton==2.0.0
typing_extensions==4.4.0
uri-template==1.3.0
urllib3==1.26.13
wadllib==1.3.6
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.1
widgetsnbextension==4.0.8
zipp==1.0.0