quick sketch of a transformer layer.
from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadedMlp(nn.Module):
    """one linear projection per requested output width; returns one tensor per head."""

    def __init__(self, input_dim, *out_dims):
        super().__init__()
        self.out_dims = out_dims
        self.heads = nn.ModuleList([nn.Linear(input_dim, dim) for dim in out_dims])

    def forward(self, i):
        return [h(i) for h in self.heads]
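# A minimal usage sketch for MultiHeadedMlp (the sizes below are assumptions for
# illustration, not values from the gist): split a 16-dim feature into a 4-dim
# head and an 8-dim head.
#
#   mlp = MultiHeadedMlp(16, 4, 8)
#   a, b = mlp(torch.randn(32, 16))  # a: (32, 4), b: (32, 8)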
class MHA(nn.Module):
    """multi-headed attention as used in the transformer"""

    def __init__(self, input_dim, d_k, heads=1):
        super().__init__()
        self.input_dim = input_dim
        self.d_k = d_k
        self.h = heads
        self.value_linear = nn.Linear(input_dim, d_k * heads)
        self.key_linear = nn.Linear(input_dim, d_k * heads)
        self.query_linear = nn.Linear(input_dim, d_k * heads)
        # output projection: merges the concatenated heads back to input_dim
        self.head = nn.Sequential(
            nn.Linear(d_k * heads, input_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        T = 1  # softmax temperature
        # suppose this is a language model: x is (B, L, input_dim)
        B, L, _ = x.shape
        v = self.value_linear(x).reshape(B, L, self.h, self.d_k).transpose(1, 2)  # (B, h, L, d_k)
        q = self.query_linear(x).reshape(B, L, self.h, self.d_k).transpose(1, 2)  # (B, h, L, d_k)
        k = self.key_linear(x).reshape(B, L, self.h, self.d_k).permute(0, 2, 3, 1)  # (B, h, d_k, L)
        w = (q @ k / T) / sqrt(self.d_k)  # (B, h, L, L)
        attn = F.softmax(w, dim=-1)  # (B, h, L, L)
        applied = (attn @ v).transpose(1, 2)  # (B, L, h, d_k)
        # concatenate the heads along the feature axis, then project back to input_dim
        return self.head(applied.reshape(B, L, self.h * self.d_k))
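
A quick shape check for the MHA sketch (the batch size, sequence length, and widths below are arbitrary assumptions, not values from the gist):

mha = MHA(input_dim=64, d_k=16, heads=4)
x = torch.randn(2, 10, 64)  # (B, L, input_dim)
out = mha(x)
print(out.shape)            # torch.Size([2, 10, 64]): heads merged and projected back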