Created April 23, 2025 21:29
Save rockerBOO/a03d169443db24a37039c83ae1c58253 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/comfy/lora.py b/comfy/lora.py | |
index 8760a21f..e6456509 100644 | |
--- a/comfy/lora.py | |
+++ b/comfy/lora.py | |
@@ -44,6 +44,12 @@ def load_lora(lora, to_load, log_missing=True): | |
alpha = lora[alpha_name].item() | |
loaded_keys.add(alpha_name) | |
+ aid_p_name = "{}.aid_p".format(x) | |
+ aid_p = None | |
+ if aid_p_name in lora.keys(): | |
+ aid_p = lora[aid_p_name].item() | |
+ loaded_keys.add(aid_p_name) | |
+ | |
dora_scale_name = "{}.dora_scale".format(x) | |
dora_scale = None | |
if dora_scale_name in lora.keys(): | |
@@ -53,6 +59,8 @@ def load_lora(lora, to_load, log_missing=True): | |
for adapter_cls in weight_adapter.adapters: | |
adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys) | |
if adapter is not None: | |
+ if hasattr(adapter, "add_aid") and aid_p is not None: | |
+ adapter.add_aid(aid_p) | |
patch_dict[to_load[x]] = adapter | |
loaded_keys.update(adapter.loaded_keys) | |
continue | |
diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py | |
index b2e62392..5f66a16a 100644 | |
--- a/comfy/weight_adapter/lora.py | |
+++ b/comfy/weight_adapter/lora.py | |
@@ -5,6 +5,20 @@ import torch | |
import comfy.model_management | |
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape | |
+class AID(torch.nn.Module):
+    """Sign-dependent scaling of a tensor, applied to ``lora_diff`` in calculate_weight.
+
+    Training mode (stochastic): each element is either kept unchanged or
+    zeroed. Non-negative elements are kept with probability ``p``; negative
+    elements are kept with probability ``1 - p``.
+
+    Eval mode (deterministic): the expected value of the training-mode output
+    is returned, i.e. ``x * p`` where ``x >= 0`` and ``x * (1 - p)`` elsewhere.
+
+    Like every ``torch.nn.Module``, a new instance starts in training mode;
+    call ``.eval()`` before using it for inference or the output is random.
+
+    Args:
+        p: Keep probability for non-negative elements; must be in [0, 1].
+
+    Raises:
+        ValueError: if ``p`` is outside [0, 1].
+    """
+
+    def __init__(self, p: float = 0.9):
+        super().__init__()
+        p = float(p)
+        # p is used as a Bernoulli probability (and 1 - p as its complement),
+        # so anything outside [0, 1] is meaningless in both modes.
+        if not 0.0 <= p <= 1.0:
+            raise ValueError(f"AID p must be in [0, 1], got {p}")
+        self.p = p
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        non_negative = x >= 0
+        if self.training:
+            # Per-element keep probability: p for x >= 0, 1 - p for x < 0.
+            # A single Bernoulli draw over this tensor is equivalent to the
+            # two masked draws, with half the random-number work.
+            keep_prob = torch.where(
+                non_negative,
+                torch.full_like(x, self.p),
+                torch.full_like(x, 1.0 - self.p),
+            )
+            return x * torch.bernoulli(keep_prob)
+        # Deterministic expected value of the stochastic branch above.
+        return torch.where(non_negative, x * self.p, x * (1.0 - self.p))
class LoRAAdapter(WeightAdapterBase): | |
name = "lora" | |
@@ -13,6 +27,8 @@ class LoRAAdapter(WeightAdapterBase): | |
self.loaded_keys = loaded_keys | |
self.weights = weights | |
+ self.aid: AID | None = None | |
+ | |
@classmethod | |
def load( | |
cls, | |
@@ -78,6 +94,9 @@ class LoRAAdapter(WeightAdapterBase): | |
else: | |
return None | |
+    def add_aid(self, aid_p: float):
+        """Attach an AID module that rescales ``lora_diff`` in calculate_weight.
+
+        Args:
+            aid_p: Scalar read from the ``<key>.aid_p`` entry of the LoRA
+                state dict by ``load_lora``.
+        """
+        # torch.nn.Module instances start in training mode, where AID.forward
+        # samples random Bernoulli masks. Switch to eval mode so the patched
+        # weights are deterministic (expected-value scaling). Module.eval()
+        # returns the module itself, so the assignment keeps the AID instance.
+        self.aid = AID(p=float(aid_p)).eval()
+ | |
def calculate_weight( | |
self, | |
weight, | |
@@ -125,6 +144,8 @@ class LoRAAdapter(WeightAdapterBase): | |
lora_diff = torch.mm( | |
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) | |
).reshape(weight.shape) | |
+ if self.aid is not None: | |
+ lora_diff = self.aid(lora_diff) | |
if dora_scale is not None: | |
weight = weight_decompose( | |
dora_scale, |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.