PyTorch Additive Attention
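This gist implements Bahdanau-style additive attention. As a quick reference (notation mine, matching the code below), each value $v_t$ is scored against the query $q$ with a small feed-forward network,

$\mathrm{score}(q, v_t) = v^\top \tanh(W\,[q\,;\,v_t]),$

and the scores are softmax-normalized over the $T$ timesteps to produce the weights used to combine the values. Here $v$ is the learned scoring vector (`self.v` in the code), not a value.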
import torch
import torch.nn as nn


class AdditiveAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # attention: W projects the concatenated [query; value] pair,
        # v scores the projection down to one energy per timestep
        self.W = nn.Linear(hidden_dim * 2, hidden_dim)  # 2H input: query and value are concatenated
        self.tanh = nn.Tanh()
        self.v = nn.Parameter(torch.empty(hidden_dim, 1))  # scoring vector
        self.softmax = nn.Softmax(dim=2)

        # initialization
        nn.init.normal_(self.v, 0, 0.1)

    def forward(self, mask, query, values):
        # mask: [B,1,T] (1 for real tokens, 0 for padding)
        # query: [B,H] (hidden state, decoder outputs, etc.)
        # values: [T,B,H] (outputs to align)
        T, B, H = values.size()

        # compute energy
        query = query.repeat(T, 1, 1)               # [B,H] -> [T,B,H]
        feats = torch.cat((query, values), dim=2)   # [T,B,H*2]
        energy = self.tanh(self.W(feats))           # [T,B,H*2] -> [T,B,H]

        # compute attention scores
        v = self.v.t().repeat(B, 1, 1)              # [H,1] -> [B,1,H]
        energy = energy.permute(1, 2, 0)            # [T,B,H] -> [B,H,T]
        scores = torch.bmm(v, energy)               # [B,1,H] x [B,H,T] -> [B,1,T]

        # normalize over time, zero out padding, then renormalize so the
        # surviving weights still sum to 1 (raw scores can be negative, so
        # the mask must be applied after the softmax, not before a plain
        # divide-by-sum)
        scores = self.softmax(scores)               # [B,1,T]
        scores = scores * mask
        scores = scores / scores.sum(2, keepdim=True)

        # weight values
        values = values.permute(1, 0, 2)              # [T,B,H] -> [B,T,H]
        combo = torch.bmm(scores, values).squeeze(1)  # [B,1,T] x [B,T,H] -> [B,H]
        return combo, scores
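
A minimal usage sketch of the module above. The sizes (B=2, T=5, H=16) and the padding pattern are arbitrary, chosen only to show the expected shapes:

B, T, H = 2, 5, 16
attn = AdditiveAttention(H)

values = torch.randn(T, B, H)   # e.g. encoder outputs, [T,B,H]
query = torch.randn(B, H)       # e.g. current decoder state, [B,H]
mask = torch.ones(B, 1, T)      # [B,1,T]
mask[0, 0, 3:] = 0              # pretend sequence 0 is padded after step 3

combo, scores = attn(mask, query, values)
print(combo.shape)              # torch.Size([2, 16])
print(scores.shape)             # torch.Size([2, 1, 5])
print(scores.sum(dim=2))        # each row of attention weights sums to 1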