@shreydesai
Created February 4, 2019 22:10
PyTorch Scaled Dot Product Attention
import torch
import torch.nn as nn
import numpy as np


class DotProductAttention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim):
        super().__init__()
        self.scale = 1.0 / np.sqrt(query_dim)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, mask, query, keys, values):
        # query: [B,Q] (hidden state, decoder output, etc.)
        # keys: [T,B,K] (encoder outputs)
        # values: [T,B,V] (encoder outputs)
        # mask: [B,1,T] (1 for valid positions, 0 for padding)
        # assume Q == K

        # compute scaled dot-product energy
        query = query.unsqueeze(1)  # [B,Q] -> [B,1,Q]
        keys = keys.permute(1, 2, 0)  # [T,B,K] -> [B,K,T]
        energy = torch.bmm(query, keys)  # [B,1,Q]*[B,K,T] = [B,1,T]
        energy = self.softmax(energy.mul_(self.scale))

        # apply mask, then renormalize so the weights sum to 1
        energy = energy * mask
        energy = energy / energy.sum(2, keepdim=True)

        # weight values by the attention distribution
        values = values.transpose(0, 1)  # [T,B,V] -> [B,T,V]
        combo = torch.bmm(energy, values).squeeze(1)  # [B,1,T]*[B,T,V] -> [B,V]
        return (combo, energy)
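
A minimal usage sketch (not part of the gist; the dimensions, tensor names, and mask values below are made up for illustration, and the mask shape [B,1,T] is assumed so it broadcasts against the [B,1,T] attention weights):

# Toy dimensions: batch of 2, source length 5, 16-dim queries/keys/values.
B, T, D = 2, 5, 16

attn = DotProductAttention(query_dim=D, key_dim=D, value_dim=D)

query = torch.randn(B, D)      # [B,Q] decoder hidden state
keys = torch.randn(T, B, D)    # [T,B,K] encoder outputs
values = torch.randn(T, B, D)  # [T,B,V] encoder outputs
mask = torch.ones(B, 1, T)     # [B,1,T] 1 = valid, 0 = padding
mask[1, 0, 3:] = 0.0           # pretend the second sequence has only 3 valid steps

context, weights = attn(mask, query, keys, values)
print(context.shape)   # torch.Size([2, 16])
print(weights.shape)   # torch.Size([2, 1, 5])
print(weights.sum(2))  # each row sums to 1 after masking and renormalization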
@jaypatravali

@yannikkumar, the softmax here is taken across time, because the attention distribution is over the time sequence.

@yanniknoc

@jaypatravali, ah yes you're right. I misinterpreted the line in the PyTorch softmax docs that reads "so every slice along dim will sum to 1". Thanks!
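
For anyone following the dim question, here is a tiny illustrative check (the numbers are arbitrary, just to show which axis `dim=2` normalizes): with energy of shape [B,1,T], the softmax runs over the T time steps, so each slice along dim=2 sums to 1.

import torch
import torch.nn as nn

# energy: [B,1,T] with B=1, T=3 (toy values)
energy = torch.tensor([[[1.0, 2.0, 3.0]]])

softmax = nn.Softmax(dim=2)  # same setting as in the gist: normalize over time
weights = softmax(energy)

print(weights)         # tensor([[[0.0900, 0.2447, 0.6652]]])
print(weights.sum(2))  # tensor([[1.]]) -- the slice along dim=2 sums to 1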
