@shreydesai
Last active October 7, 2024 22:59
PyTorch Additive Attention
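For reference, the module below implements the additive (Bahdanau-style) attention score: for a query q and values h_1..h_T it computes, roughly,

    e_t = v^\top \tanh(W [q; h_t]), \qquad \alpha = \operatorname{softmax}(e), \qquad c = \sum_t \alpha_t h_t

where c is the attention-weighted combination of the values.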
import torch
import torch.nn as nn


class AdditiveAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # attention
        self.W = nn.Linear(hidden_dim * 2, hidden_dim)  # projects the [query; value] concat
        self.tanh = nn.Tanh()
        self.v = nn.Parameter(torch.empty(hidden_dim, 1))  # context vector
        self.softmax = nn.Softmax(dim=2)
        # initialization
        nn.init.normal_(self.v, 0, 0.1)

    def forward(self, mask, query, values):
        # mask: [B,1,T] (1 for real tokens, 0 for padding)
        # query: [B,H] (hidden state, decoder output, etc.)
        # values: [T,B,H] (outputs to align)
        T, B, H = values.size()

        # compute energy
        query = query.repeat(T, 1, 1)              # [B,H] -> [T,B,H]
        feats = torch.cat((query, values), dim=2)  # [T,B,H*2]
        energy = self.tanh(self.W(feats))          # [T,B,H*2] -> [T,B,H]

        # compute attention scores
        v = self.v.t().repeat(B, 1, 1)               # [H,1] -> [B,1,H]
        energy = energy.permute(1, 2, 0)             # [T,B,H] -> [B,H,T]
        scores = self.softmax(torch.bmm(v, energy))  # [B,1,H] x [B,H,T] -> [B,1,T]

        # zero out padding positions, then renormalize over T
        scores = scores * mask
        scores = scores / scores.sum(2, keepdim=True)

        # weight values
        values = values.permute(1, 0, 2)              # [T,B,H] -> [B,T,H]
        combo = torch.bmm(scores, values).squeeze(1)  # [B,1,T] x [B,T,H] -> [B,H]
        return combo, scores
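A minimal smoke test of the module (the shapes and the mask convention of 1 = real token, 0 = padding are illustrative assumptions, not part of the original gist):

import torch

B, T, H = 4, 9, 16
attn = AdditiveAttention(hidden_dim=H)
values = torch.randn(T, B, H)  # e.g., encoder outputs
query = torch.randn(B, H)      # e.g., current decoder hidden state
mask = torch.ones(B, 1, T)
mask[:, :, 7:] = 0             # pretend the last two positions are padding
combo, scores = attn(mask, query, values)
print(combo.shape)   # torch.Size([4, 16])
print(scores.shape)  # torch.Size([4, 1, 9])

Note that masking here is applied after the softmax and the weights are renormalized, so the attention over non-padding positions still sums to 1.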