Skip to content

Instantly share code, notes, and snippets.

Last active October 20, 2021 13:42
Show Gist options
  • Save Guitaricet/7a8fb70514bd6ad0f7d5aca5bdf17cfa to your computer and use it in GitHub Desktop.
Save Guitaricet/7a8fb70514bd6ad0f7d5aca5bdf17cfa to your computer and use it in GitHub Desktop.
Very simple self-attention implementation
# Plain (single-head) scaled dot-product self-attention
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention (no masking).

    Queries, keys and values are independent linear projections of the same
    input ``x``. Every position attends over all positions of the sequence.

    Args:
        dim: size of the last (feature) axis of the input. It is also the size
            of the query/key/value projections and of the output.
    """

    def __init__(self, dim=7):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise PyTorch raises "cannot assign module before
        # Module.__init__() call".
        super().__init__()
        self.K = nn.Linear(dim, dim)
        self.Q = nn.Linear(dim, dim)
        self.V = nn.Linear(dim, dim)
        # Standard Transformer scaling: divide scores by sqrt(dim) so their
        # magnitude does not grow with the feature size.
        self.scale = dim ** 0.5

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (..., seq, dim), e.g. (bs, seq, dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        q, k, v = self.Q(x), self.K(x), self.V(x)  # (..., seq, dim)
        # Row i holds query i's scores against every key j; normalise over keys.
        scores = q @ k.transpose(-2, -1) / self.scale  # (..., seq, dim) @ (..., dim, seq)
        alpha = F.softmax(scores, dim=-1)  # (..., seq, seq)
        return alpha @ v  # (..., seq, seq) @ (..., seq, dim) -> (..., seq, dim)
# Making it multi-head takes a fair amount of extra code.
# einops helps write clearer shape transformations.
from einops import rearrange
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking).

    The projected features (size ``dim``) are split into ``n_heads`` equal
    chunks. Each head attends independently, with heads folded into the batch
    axis, and the concatenated head outputs are mixed by a final linear
    layer ``U``.

    Args:
        dim: feature size of the input and output; must be divisible by
            ``n_heads``.
        n_heads: number of attention heads (a positive integer).

    Raises:
        ValueError: if ``n_heads`` is not positive or ``dim`` is not divisible
            by ``n_heads``.
    """

    def __init__(self, dim=8, n_heads=1):
        # Must run before submodules are assigned (PyTorch requirement).
        super().__init__()
        if n_heads < 1 or dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.K = nn.Linear(dim, dim)
        self.Q = nn.Linear(dim, dim)
        self.V = nn.Linear(dim, dim)
        self.U = nn.Linear(dim, dim)  # used to mix the heads
        # Scale by sqrt of the per-head size, since each head's dot products
        # run over dim / n_heads features.
        self.scale = (dim / n_heads) ** 0.5
        self.n_heads = n_heads

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (bs, seq, dim).

        Returns:
            Tensor of shape (bs, seq, dim).
        """
        k, q, v = self.K(x), self.Q(x), self.V(x)  # (bs, seq, hid)
        # Split heads and process them independently, like extra batch elements.
        split_heads = 'bs seq (head k) -> (bs head) seq k'
        k = rearrange(k, split_heads, head=self.n_heads)
        q = rearrange(q, split_heads, head=self.n_heads)
        v = rearrange(v, split_heads, head=self.n_heads)
        # Row i holds query i's scores against every key j; normalise over keys.
        # (bs * head, seq, hid / head) @ (bs * head, hid / head, seq)
        scores = q @ k.transpose(1, 2) / self.scale
        alpha = F.softmax(scores, dim=-1)  # (bs * head, seq, seq)
        attn = alpha @ v  # (bs * head, seq, seq) @ (bs * head, seq, hid / head)
        # Merge heads back into the feature axis.
        attn = rearrange(attn, '(bs head) seq k -> bs seq (head k)', head=self.n_heads)
        return self.U(attn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment