Gist laksjdjf/a5f1278137f9a7e979dc875784efaaf7, created September 21, 2023 12:31 by @laksjdjf
import torch
import comfy.utils
import copy

def chunk_or_none(x, chunk_size, index):
    # Split a batched tensor into `chunk_size` groups along dim 0 and
    # return the `index`-th group; None passes through unchanged.
    if x is None:
        return None
    return x.chunk(chunk_size, dim=0)[index]
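
# Illustrative sketch: with two prompts and a batch built by
# concatenating their latents along dim 0,
#   x = torch.randn(4, 4, 64, 64)
#   chunk_or_none(x, 2, 0)      # first prompt's (2, 4, 64, 64) slice
#   chunk_or_none(None, 2, 0)   # None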

def chunk_or_none_for_control(x, chunk_size, index):
    # Same chunking, applied to every ControlNet residual stored under
    # the "input"/"middle"/"output" keys.
    if x is None:
        return None
    control = {}
    for key in ("input", "middle", "output"):
        if key in x:
            control[key] = [s.chunk(chunk_size, dim=0)[index] for s in x[key]]
    return control
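
# For reference: ComfyUI's control dict maps "input"/"middle"/"output" to
# lists of residual tensors, each batched along dim 0 like the latent
# itself, which is why the same per-prompt chunking applies to them.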

def add_patches(regional_lora, patches, number):
    # Merge `patches` into the entry for prompt `number`. New containers
    # are built instead of extending lists in place, so a shallow copy of
    # `regional_lora` made upstream stays independent of the original.
    merged = dict(regional_lora.get(number, {}))
    for key in patches:
        if key in merged:
            merged[key] = merged[key] + patches[key]
        else:
            merged[key] = patches[key]
    regional_lora[number] = merged
    return regional_lora
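
# Resulting structure (illustrative): prompt index -> the same
# {weight key: [patch, ...]} mapping that ModelPatcher keeps in
# `model.patches`, e.g.
#   {0: {"diffusion_model....weight": [patch, ...]}, 1: {...}}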

class RegionalLoRALoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "prompt_id": ("INT", {
                    "default": 1,
                    "min": 0,
                    "max": 10,
                    "step": 1,
                    "display": "number",
                }),
            },
            "optional": {
                "regional_lora": ("REGIONAL_LORA", ),
            },
        }

    RETURN_TYPES = ("REGIONAL_LORA", )
    FUNCTION = "apply"
    CATEGORY = "regional_lora"

    def apply(self, model, prompt_id, regional_lora=None):
        # Record the LoRA patches already attached to this model clone
        # (e.g. by a standard LoraLoader) under the given prompt_id.
        if regional_lora is None:
            new_regional_lora = {}
        else:
            new_regional_lora = copy.copy(regional_lora)
        new_regional_lora = add_patches(new_regional_lora, model.patches, prompt_id)
        return (new_regional_lora, )
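
# Typical wiring (a sketch of the intended workflow, not enforced by the
# code): load each prompt's LoRA with a standard LoraLoader on its own
# branch from the checkpoint, pass each result through this node with a
# distinct prompt_id, and chain the REGIONAL_LORA outputs together.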

class ApplyRegionalLoRA:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "regional_lora": ("REGIONAL_LORA", ),
                "num_prompts": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 10,
                    "step": 1,
                    "display": "number",
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "regional_lora"

    def apply(self, model, regional_lora, num_prompts):
        self.num_prompts = num_prompts
        new_model = model.clone()

        def apply_regional(model_function, kwargs):
            # kwargs follows ComfyUI's unet-wrapper convention:
            # {"input": x, "timestep": t, "c": conditioning dict, ...}.
            input = kwargs["input"]
            timestep = kwargs["timestep"]
            c_concat = kwargs["c"].get("c_concat", None)
            c_crossattn = kwargs["c"].get("c_crossattn", None)
            c_adm = kwargs["c"].get("c_adm", None)
            control = kwargs["c"].get("control", None)
            transformer_options = kwargs["c"].get("transformer_options", None)
            retvals = []
            for i in range(self.num_prompts):
                # Patch in this prompt's LoRA weights (if any), run the
                # model on this prompt's slice of the batch, then restore.
                exist_patch = i in regional_lora
                if exist_patch:
                    patch_model(new_model, regional_lora[i])
                retval = model_function(
                    chunk_or_none(input, self.num_prompts, i),
                    chunk_or_none(timestep, self.num_prompts, i),
                    chunk_or_none(c_concat, self.num_prompts, i),
                    chunk_or_none(c_crossattn, self.num_prompts, i),
                    chunk_or_none(c_adm, self.num_prompts, i),
                    chunk_or_none_for_control(control, self.num_prompts, i),
                    transformer_options,
                )
                retvals.append(retval)
                if exist_patch:
                    unpatch_model(new_model)
            return torch.cat(retvals, dim=0)

        new_model.set_model_unet_function_wrapper(apply_regional)
        return (new_model, )
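
# Assumption baked into the wrapper: the incoming batch is the
# concatenation of num_prompts equal-sized groups along dim 0, ordered by
# prompt index, so group i lines up with the LoRA stored under key i.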

def patch_model(model, patches):
    # Temporarily apply LoRA patches to the live model weights, keeping
    # the originals in `backup_2` (a separate store from ModelPatcher's
    # own `backup`, so the two never collide).
    model.backup_2 = {}
    model_sd = model.model_state_dict()
    for key in patches:
        if key not in model_sd:
            print("could not patch. key doesn't exist in model:", key)
            continue
        weight = model_sd[key]
        if key not in model.backup_2:
            model.backup_2[key] = weight
        temp_weight = weight.to(torch.float32, copy=True)
        out_weight = model.calculate_weight(patches[key], temp_weight, key).to(weight.dtype)
        comfy.utils.set_attr(model.model, key, out_weight)
        del temp_weight

def unpatch_model(model):
    # Restore the original weights saved by patch_model.
    keys = list(model.backup_2.keys())
    for k in keys:
        comfy.utils.set_attr(model.model, k, model.backup_2[k])
    model.backup_2 = {}

NODE_CLASS_MAPPINGS = {
    "RegionalLoRALoader": RegionalLoRALoader,
    "ApplyRegionalLoRA": ApplyRegionalLoRA,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "RegionalLoRALoader": "Regional LoRA Loader",
    "ApplyRegionalLoRA": "Apply Regional LoRA",
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
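
if __name__ == "__main__":
    # Minimal smoke test of the pure helpers (a sketch; the node classes
    # themselves, and the comfy import above, need the ComfyUI runtime).
    x = torch.arange(8).reshape(4, 2).float()   # batch of 4 = 2 prompts x 2
    assert chunk_or_none(None, 2, 0) is None
    assert chunk_or_none(x, 2, 1).shape == (2, 2)
    control = {"output": [torch.zeros(4, 3)]}
    assert chunk_or_none_for_control(control, 2, 0)["output"][0].shape == (2, 3)
    lora = add_patches({}, {"k": [1]}, 0)
    lora = add_patches(lora, {"k": [2]}, 0)
    assert lora[0]["k"] == [1, 2]
    print("ok")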