Created
September 21, 2023 12:31
-
-
Save laksjdjf/a5f1278137f9a7e979dc875784efaaf7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import comfy | |
import copy | |
def chunk_or_none(x, chunk_size, index):
    """Return one piece of ``x`` split along dim 0, passing ``None`` through.

    Args:
        x: Object exposing ``chunk(chunks, dim=...)`` (a tensor in this file),
            or ``None``.
        chunk_size: Value forwarded as the first argument of ``x.chunk``.
            Despite the name, it is the number of chunks requested, not the
            size of each chunk.
        index: Position of the chunk to return.

    Returns:
        ``None`` when ``x`` is ``None``; otherwise
        ``x.chunk(chunk_size, dim=0)[index]``.
    """
    if x is not None:
        pieces = x.chunk(chunk_size, dim=0)
        return pieces[index]
    return None
def chunk_or_none_for_control(x, chunk_size, index):
    """Slice every tensor of a control dict along dim 0.

    Only the ``"input"``, ``"middle"`` and ``"output"`` keys are carried over;
    any other key present in ``x`` is dropped. Each of those keys is expected
    to map to an iterable of objects exposing ``chunk(chunks, dim=...)``.

    Args:
        x: Control dict, or ``None``.
        chunk_size: Number of chunks requested from each ``chunk`` call.
        index: Position of the chunk to keep from each split.

    Returns:
        ``None`` when ``x`` is ``None``; otherwise a new dict holding, for
        each recognised key present in ``x``, a list with the selected chunk
        of every entry.
    """
    if x is None:
        return None

    control = {}
    for group in ("input", "middle", "output"):
        if group not in x:
            continue
        # A list may contain None placeholders (presumably for layers without
        # a control signal -- confirm against the producer); keep them as-is
        # rather than failing on ``None.chunk``.
        control[group] = [
            None if s is None else s.chunk(chunk_size, dim=0)[index]
            for s in x[group]
        ]
    return control
def add_patches(regional_lora, patches, number):
    """Merge ``patches`` into the entry of ``regional_lora`` stored under ``number``.

    For every key of ``patches`` its values are appended to the list already
    stored for that key under ``number``, or stored as a new list when the
    key (or ``number`` itself) is not present yet.

    The per-``number`` entry is rebuilt from fresh lists, so neither the
    lists owned by ``patches`` nor containers that were previously stored in
    ``regional_lora[number]`` (and may be shared with a shallow copy of the
    mapping) are mutated.

    Args:
        regional_lora: Dict mapping a prompt number to a dict of
            ``key -> list of patches``. Updated in place at ``number``.
        patches: Dict of ``key -> iterable of patches`` to merge in.
        number: Prompt number whose entry receives the patches.

    Returns:
        The same ``regional_lora`` object that was passed in.
    """
    # Copy what is already registered so shared inner containers stay intact.
    merged = {key: list(value) for key, value in regional_lora.get(number, {}).items()}
    for key in patches:
        # extend() onto a list we own; never onto the caller's list.
        merged.setdefault(key, []).extend(patches[key])
    regional_lora[number] = merged
    return regional_lora
class RegionalLoRALoader:
    """Node that records a model's patches under a prompt id.

    The output is a ``REGIONAL_LORA`` mapping of
    ``prompt_id -> {key -> list of patches}`` that can be chained through
    several loaders via the optional ``regional_lora`` input.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a model, a prompt id and an optional mapping."""
        return {
            "required": {
                "model": ("MODEL", ),
                "prompt_id": ("INT", {
                    "default": 1,
                    "min": 0,
                    "max": 10,
                    "step": 1,
                    "display": "number"
                }),
            },
            "optional": {
                "regional_lora": ("REGIONAL_LORA", ),
            },
        }

    RETURN_TYPES = ("REGIONAL_LORA", )
    FUNCTION = "apply"
    CATEGORY = "regional_lora"

    def apply(self, model, prompt_id, regional_lora=None):
        """Return a new mapping with ``model.patches`` merged in under ``prompt_id``.

        Args:
            model: Object exposing a ``patches`` dict of ``key -> list``.
            prompt_id: Prompt number the patches are registered under.
            regional_lora: Mapping produced by a previous loader, or ``None``
                to start from an empty one.

        Returns:
            One-element tuple holding the new mapping.
        """
        # A shallow copy would share the per-prompt dicts and their lists with
        # the incoming mapping, so merging would silently modify the upstream
        # node's output. Copy both nesting levels instead.
        new_regional_lora = {}
        if regional_lora is not None:
            for number, per_prompt in regional_lora.items():
                new_regional_lora[number] = {
                    key: list(value) for key, value in per_prompt.items()
                }

        # Hand over copies so the model's own patch lists are never aliased
        # (and later extended) by the mapping.
        own_patches = {key: list(value) for key, value in model.patches.items()}
        new_regional_lora = add_patches(new_regional_lora, own_patches, prompt_id)
        return (new_regional_lora, )
class ApplyRegionalLoRA:
    """Node that evaluates each prompt's slice of the batch with its own patches.

    ``apply`` installs a UNet function wrapper on a clone of the model. The
    wrapper splits every batched argument into ``num_prompts`` chunks along
    dim 0, runs the wrapped function once per chunk (with the patches
    registered for that chunk index applied for the duration of the call)
    and concatenates the results along dim 0.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a model, a regional mapping and a prompt count."""
        return {
            "required": {
                "model": ("MODEL", ),
                "regional_lora": ("REGIONAL_LORA", ),
                "num_prompts": ("INT", {
                    "default": 1,
                    "min": 0,
                    "max": 10,
                    "step": 1,
                    "display": "number"
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "regional_lora"

    def apply(self, model, regional_lora, num_prompts):
        """Return a clone of ``model`` wrapped to apply per-prompt patches.

        Args:
            model: Object exposing ``clone()``; the clone must expose
                ``set_model_unet_function_wrapper``.
            regional_lora: Mapping of chunk index to the patches to apply
                while that chunk is evaluated.
            num_prompts: Number of chunks the batch is split into.

        Returns:
            One-element tuple holding the wrapped clone.
        """
        # Kept for backward compatibility; the wrapper itself uses the value
        # captured by the closure so a later apply() on this node instance
        # cannot change the behaviour of an already-built model.
        self.num_prompts = num_prompts
        new_model = model.clone()

        def apply_regional(model_function, kwargs):
            cond = kwargs["c"]
            latent = kwargs["input"]
            timestep = kwargs["timestep"]
            c_concat = cond.get("c_concat", None)
            c_crossattn = cond.get("c_crossattn", None)
            c_adm = cond.get("c_adm", None)
            control = cond.get("control", None)
            transformer_options = cond.get("transformer_options", None)

            retvals = []
            for i in range(num_prompts):
                exist_patch = i in regional_lora
                if exist_patch:
                    patch_model(new_model, regional_lora[i])
                try:
                    retval = model_function(
                        chunk_or_none(latent, num_prompts, i),
                        chunk_or_none(timestep, num_prompts, i),
                        chunk_or_none(c_concat, num_prompts, i),
                        chunk_or_none(c_crossattn, num_prompts, i),
                        chunk_or_none(c_adm, num_prompts, i),
                        chunk_or_none_for_control(control, num_prompts, i),
                        transformer_options,
                    )
                finally:
                    # Always restore the original weights, even when the
                    # model call raises, so the model is never left patched.
                    if exist_patch:
                        unpatch_model(new_model)
                retvals.append(retval)

            return torch.cat(retvals, dim=0)

        new_model.set_model_unet_function_wrapper(apply_regional)
        return (new_model, )
def patch_model(model, patches):
    """Apply ``patches`` to the model's weights, remembering the originals.

    ``model.backup_2`` is reset to an empty dict and then filled with the
    original weight of every key that gets patched, so that
    ``unpatch_model`` can restore them afterwards.

    Args:
        model: Object exposing ``model_state_dict()``, ``calculate_weight``
            and a ``model`` attribute whose weights are replaced.
        patches: Mapping of state-dict key to the patch data handed to
            ``model.calculate_weight``.
    """
    model.backup_2 = {}
    state = model.model_state_dict()

    for name in patches:
        if name in state:
            original = state[name]
            # Keep the first weight seen for this key as the restore point.
            model.backup_2.setdefault(name, original)

            # The patch is computed on a float32 copy and cast back to the
            # original dtype before being written to the model.
            scratch = original.to(torch.float32, copy=True)
            patched = model.calculate_weight(patches[name], scratch, name).to(original.dtype)
            comfy.utils.set_attr(model.model, name, patched)
            del scratch
        else:
            print("could not patch. key doesn't exist in model:", name)
def unpatch_model(model):
    """Write every weight saved in ``model.backup_2`` back and clear the backup.

    Args:
        model: Object exposing a ``backup_2`` dict of ``key -> weight`` and a
            ``model`` attribute the weights are written to.
    """
    saved = model.backup_2
    # Iterate over a snapshot of the keys, as the restore loop should not
    # depend on the dict staying untouched while attributes are set.
    for name in list(saved):
        comfy.utils.set_attr(model.model, name, saved[name])
    model.backup_2 = {}
# Node classes keyed by their identifier; each identifier is the class name.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (RegionalLoRALoader, ApplyRegionalLoRA)
}

# Human-readable titles, keyed by the same identifiers as above.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    RegionalLoRALoader="Regional LoRA Loader",
    ApplyRegionalLoRA="Apply Regional LoRA",
)

# Only the two registration tables are exported from this module.
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment