
@opparco
Created September 5, 2022 09:21
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index f4eff39..f90c6c4 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -174,23 +174,27 @@ class CrossAttention(nn.Module):
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
+        del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        del q, k
 
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = repeat(mask, 'b j -> (b h) () j', h=h)
             sim.masked_fill_(~mask, max_neg_value)
+            del mask
 
         # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
+        sim[4:] = sim[4:].softmax(dim=-1)
+        sim[:4] = sim[:4].softmax(dim=-1)
 
-        out = einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)
+        sim = einsum('b i j, b j d -> b i d', sim, v)
+        sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(sim)
 
 
 class BasicTransformerBlock(nn.Module):
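
The patch frees intermediates (`del context, x`, `del q, k`, `del mask`) as soon as they are no longer needed and computes the softmax in-place over two slices of `sim` instead of materializing a separate `attn` tensor, which lowers peak VRAM during the attention step. Below is a minimal, self-contained sketch (not part of the gist) of that sliced in-place softmax; the function name, the split size default, and the equivalence check are illustrative only.

import torch

def chunked_inplace_softmax(sim: torch.Tensor, split: int = 4) -> torch.Tensor:
    """Softmax over the last dim, written back into `sim` in two slices.

    Each slice allocates only a slice-sized temporary before being copied
    back into `sim`, instead of a second full (b*h, n, n) tensor.
    The split of 4 mirrors the boundary used in the patch; any split along
    the batch*heads dimension gives the same result.
    """
    sim[split:] = sim[split:].softmax(dim=-1)
    sim[:split] = sim[:split].softmax(dim=-1)
    return sim

# Equivalence check against the ordinary softmax:
x = torch.randn(8, 64, 64)
assert torch.allclose(chunked_inplace_softmax(x.clone()), x.softmax(dim=-1))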