Very simple self-attention implementation
# Just dot-product self-attention
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, dim=7):
        super().__init__()
        self.K = nn.Linear(dim, dim)
        self.Q = nn.Linear(dim, dim)
        self.V = nn.Linear(dim, dim)
        self.scale = dim ** 0.5

    def forward(self, x):
        k, q, v = self.K(x), self.Q(x), self.V(x)  # each (bs, seq, hid)
        # attention weights: (bs, seq, hid) @ (bs, hid, seq) -> (bs, seq, seq)
        alpha = F.softmax(q @ k.transpose(1, 2) / self.scale, dim=-1)
        # weighted sum of the values: (bs, seq, seq) @ (bs, seq, hid) -> (bs, seq, hid)
        return alpha @ v
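

# Quick sanity check (not part of the original gist; the batch and sequence
# sizes below are arbitrary): the output should keep the input shape.
x = torch.randn(2, 5, 7)        # (bs, seq, hid), hid matches dim=7
attn = SelfAttention(dim=7)
print(attn(x).shape)            # torch.Size([2, 5, 7])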
# To make it multi-head we need to add quite some code.
# einops helps to write shape transformations more clearly.
from einops import rearrange


class MultiHeadAttention(nn.Module):
    def __init__(self, dim=8, n_heads=1):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"
        self.K = nn.Linear(dim, dim)
        self.Q = nn.Linear(dim, dim)
        self.V = nn.Linear(dim, dim)
        self.U = nn.Linear(dim, dim)  # output projection, used to mix the heads
        self.scale = (dim / n_heads) ** 0.5  # sqrt of the per-head dimension
        self.n_heads = n_heads

    def forward(self, x):
        bsz, seq, dim = x.shape  # (bs, seq, hid)
        k, q, v = self.K(x), self.Q(x), self.V(x)  # each (bs, seq, hid)
        # split heads - process them independently, just like different elements in the batch
        k = rearrange(k, 'bs seq (head k) -> (bs head) seq k', head=self.n_heads)
        q = rearrange(q, 'bs seq (head k) -> (bs head) seq k', head=self.n_heads)
        v = rearrange(v, 'bs seq (head k) -> (bs head) seq k', head=self.n_heads)
        # (bs * head, seq, hid / head) @ (bs * head, hid / head, seq) -> (bs * head, seq, seq)
        alpha = F.softmax(q @ k.transpose(1, 2) / self.scale, dim=-1)
        # (bs * head, seq, seq) @ (bs * head, seq, hid / head) -> (bs * head, seq, hid / head)
        attn = alpha @ v
        # concatenate the heads back and mix them with the output projection
        attn = rearrange(attn, '(bs head) seq k -> bs seq (head k)', head=self.n_heads)
        attn = self.U(attn)
        return attn
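

# Quick usage sketch (not part of the original gist; sizes are illustrative
# assumptions): with dim=8 and n_heads=2 each head works on 4 features,
# and the output keeps the input shape.
x = torch.randn(2, 5, 8)                    # (bs, seq, hid), hid matches dim=8
mha = MultiHeadAttention(dim=8, n_heads=2)
print(mha(x).shape)                         # torch.Size([2, 5, 8])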