@matthewchung74
Created February 5, 2021 23:36
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Multi-head self-attention block as used in the Vision Transformer (timm-style)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # scale factor 1/sqrt(head_dim) unless an explicit qk_scale is given
        self.scale = qk_scale or head_dim ** -0.5

        # one linear layer produces q, k and v in a single pass
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape  # batch size, number of tokens, embedding dim
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # scaled dot-product attention: q multiplied by k, then softmax over the last axis
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # weight the values, then merge the heads back into (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
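
A minimal usage sketch follows for context. The hyperparameters (dim=768, num_heads=12, qkv_bias=True) and the token count of 197 (196 patches plus a class token) are assumptions chosen to resemble a typical ViT-Base setup; they are not specified by the snippet above.

# Usage sketch (shapes and hyperparameters are illustrative assumptions)
attn = Attention(dim=768, num_heads=12, qkv_bias=True)
tokens = torch.randn(2, 197, 768)   # (batch, num_patches + class token, embedding dim)
out = attn(tokens)
print(out.shape)                    # torch.Size([2, 197, 768]) -- shape is preserved

Because the qkv projection and the output projection both map dim -> dim (the qkv layer to 3*dim, split back into three tensors), the block can be dropped into a transformer encoder layer without changing the token shape.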