@rockerBOO
Created April 23, 2025 21:29
diff --git a/comfy/lora.py b/comfy/lora.py
index 8760a21f..e6456509 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -44,6 +44,12 @@ def load_lora(lora, to_load, log_missing=True):
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)
 
+        aid_p_name = "{}.aid_p".format(x)
+        aid_p = None
+        if aid_p_name in lora.keys():
+            aid_p = lora[aid_p_name].item()
+            loaded_keys.add(aid_p_name)
+
         dora_scale_name = "{}.dora_scale".format(x)
         dora_scale = None
         if dora_scale_name in lora.keys():
@@ -53,6 +59,8 @@ def load_lora(lora, to_load, log_missing=True):
         for adapter_cls in weight_adapter.adapters:
             adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
             if adapter is not None:
+                if hasattr(adapter, "add_aid") and aid_p is not None:
+                    adapter.add_aid(aid_p)
                 patch_dict[to_load[x]] = adapter
                 loaded_keys.update(adapter.loaded_keys)
                 continue
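
Note (outside the patch): the hunks above only thread the new per-layer scalar through the loader. As a minimal sketch, assuming a hypothetical layer prefix and illustrative tensor shapes (only the ".aid_p" key suffix is what the load_lora() change actually reads), a LoRA state dict carrying the keep probability would look like:

import torch

# Hypothetical LoRA state dict: the "lora_unet_example" prefix and the
# shapes are made up for illustration; the ".aid_p" suffix matches the
# key the loader change above looks for.
lora = {
    "lora_unet_example.lora_up.weight": torch.randn(320, 8),
    "lora_unet_example.lora_down.weight": torch.randn(8, 320),
    "lora_unet_example.alpha": torch.tensor(8.0),
    "lora_unet_example.aid_p": torch.tensor(0.9),  # keep prob for positive entries
}

x = "lora_unet_example"
aid_p_name = "{}.aid_p".format(x)
aid_p = lora[aid_p_name].item() if aid_p_name in lora.keys() else None
print(aid_p)  # 0.9, later handed to adapter.add_aid(aid_p)
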
diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py
index b2e62392..5f66a16a 100644
--- a/comfy/weight_adapter/lora.py
+++ b/comfy/weight_adapter/lora.py
@@ -5,6 +5,20 @@ import torch
 import comfy.model_management
 from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
 
+class AID(torch.nn.Module):
+    def __init__(self, p=0.9):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x: torch.Tensor):
+        if self.training:
+            pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
+            neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
+            return x * (pos_mask + neg_mask)
+        else:
+            pos_part = (x >= 0) * x * self.p
+            neg_part = (x < 0) * x * (1 - self.p)
+            return pos_part + neg_part
 
 class LoRAAdapter(WeightAdapterBase):
     name = "lora"
@@ -13,6 +27,8 @@ class LoRAAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+        self.aid: AID | None = None
+
     @classmethod
     def load(
         cls,
@@ -78,6 +94,9 @@ class LoRAAdapter(WeightAdapterBase):
         else:
             return None
 
+    def add_aid(self, aid_p: float):
+        self.aid = AID(p=aid_p)
+
     def calculate_weight(
         self,
         weight,
@@ -125,6 +144,8 @@ class LoRAAdapter(WeightAdapterBase):
             lora_diff = torch.mm(
                 mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
             ).reshape(weight.shape)
+            if self.aid is not None:
+                lora_diff = self.aid(lora_diff)
             if dora_scale is not None:
                 weight = weight_decompose(
                     dora_scale,
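
For reference, a self-contained sanity check of the AID module introduced above. The class body is copied from the patch; the harness around it is a sketch, not part of the gist. In training mode, positive entries survive with probability p and negative entries with probability 1 - p; in eval mode the same factors are applied deterministically, so the eval output equals the expectation of the training output:

import torch

class AID(torch.nn.Module):
    # Asymmetric dropout: positive entries are kept with probability p,
    # negative entries with probability 1 - p.
    def __init__(self, p=0.9):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor):
        if self.training:
            pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
            neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
            return x * (pos_mask + neg_mask)
        else:
            pos_part = (x >= 0) * x * self.p
            neg_part = (x < 0) * x * (1 - self.p)
            return pos_part + neg_part

torch.manual_seed(0)
aid = AID(p=0.9)
x = torch.randn(10_000)

aid.train()
mc_mean = torch.stack([aid(x) for _ in range(1_000)]).mean(dim=0)

aid.eval()
expected = aid(x)

# The Monte Carlo average of the stochastic branch converges to the
# deterministic eval branch.
print((mc_mean - expected).abs().mean())  # small, shrinks with more samples

One thing to note: a freshly constructed torch.nn.Module starts in training mode (self.training is True), so as written the stochastic branch is what runs when calculate_weight() applies the adapter; calling .eval() on the AID instance first would select the deterministic expectation branch instead.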