Created September 5, 2022 09:21
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index f4eff39..f90c6c4 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -174,23 +174,27 @@ class CrossAttention(nn.Module):
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
+        del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k
 
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
+            del mask
 
         # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
+        sim[4:] = sim[4:].softmax(dim=-1)
+        sim[:4] = sim[:4].softmax(dim=-1)
 
-        out = einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)
+        sim = einsum('b i j, b j d -> b i d', sim, v)
+        sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(sim)
 
 
 class BasicTransformerBlock(nn.Module):
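The patch trims peak VRAM in CrossAttention in two ways: it `del`s `context`, `x`, `q`, `k`, and `mask` as soon as they are no longer needed, so PyTorch's caching allocator can reuse that memory while the large attention map is still live, and it replaces `attn = sim.softmax(dim=-1)` (which allocates a second tensor the size of `sim`) with an in-place softmax applied to two slices of the `(batch*heads)` dimension. The following is a minimal standalone sketch of that sliced softmax, not part of the gist itself; the function name `sliced_softmax_`, the tensor shapes, and the demo values are illustrative, and the split point of 4 is simply the one hard-coded in the diff.

import torch

def sliced_softmax_(sim: torch.Tensor, split: int = 4) -> torch.Tensor:
    """Apply softmax along the last dim in two slices, writing back into `sim`.

    Compared with `attn = sim.softmax(dim=-1)`, which materializes a second
    tensor the size of `sim`, this keeps only one slice-sized temporary alive
    at a time, lowering peak memory for large attention maps.
    """
    # Each line allocates a temporary only as large as its slice, then copies
    # the result back into the existing storage of `sim`.
    sim[split:] = sim[split:].softmax(dim=-1)
    sim[:split] = sim[:split].softmax(dim=-1)
    return sim

if __name__ == "__main__":
    torch.manual_seed(0)
    sim = torch.randn(16, 64, 64)        # stand-in for (batch*heads, i, j) scores
    reference = sim.softmax(dim=-1)      # original single-pass path (full-size copy)
    out = sliced_softmax_(sim.clone())
    print(torch.allclose(reference, out))  # True: same result, only peak memory differs

One design note: if `batch*heads` is 4 or smaller, the first slice is empty and the second covers the whole tensor, so the code degrades gracefully to the original single-pass behaviour with no memory saving.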