Created February 4, 2019 22:10
Save shreydesai/3b4c5ee9ea135a7693c5886078257371 to your computer and use it in GitHub Desktop.
PyTorch Scaled Dot Product Attention
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import torch | |
import torch.nn as nn | |
import numpy as np | |
class DotProductAttention(nn.Module):
    """Scaled dot-product attention of a single query over a sequence of keys.

    Scores are ``query . key * 1/sqrt(query_dim)``, normalized with a softmax
    over the time axis, optionally masked and renormalized, and then used as
    weights for a sum over ``values``.

    Args:
        query_dim: Size ``Q`` of the query vectors; sets the ``1/sqrt(Q)``
            scaling factor.
        key_dim: Size ``K`` of the key vectors. Not used by this module
            (kept for interface compatibility); ``forward`` requires K == Q.
        value_dim: Size ``V`` of the value vectors. Not used by this module
            (kept for interface compatibility).
    """

    def __init__(self, query_dim, key_dim, value_dim):
        super().__init__()
        # Plain Python float so the multiplication is a simple scalar op
        # regardless of operand order (np.float64 on the left of a tensor
        # would go through NumPy's operator instead).
        self.scale = 1.0 / float(np.sqrt(query_dim))
        # dim=2 is the T axis of the [B,1,T] score tensor.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, mask, query, keys, values):
        """Attend from ``query`` over ``keys`` and pool ``values``.

        Args:
            mask: Multiplicative mask broadcastable to ``[B, 1, T]``
                (e.g. 1 for positions to keep, 0 for positions to ignore),
                or ``None`` to attend over all ``T`` positions.
            query: ``[B, Q]`` query (hidden state, decoder output, etc.).
            keys: ``[T, B, K]`` keys (e.g. encoder outputs); requires K == Q.
            values: ``[T, B, V]`` values (e.g. encoder outputs).

        Returns:
            ``(combo, energy)``: ``combo`` is the ``[B, V]`` weighted sum of
            ``values``; ``energy`` is the ``[B, 1, T]`` attention weights.
            The weights sum to 1 over ``T`` for every row with a nonzero
            masked total; a fully masked row is all zeros (not NaN).
        """
        # Compute scaled scores.
        query = query.unsqueeze(1)  # [B,Q] -> [B,1,Q]
        keys = keys.permute(1, 2, 0)  # [T,B,K] -> [B,K,T]
        energy = torch.bmm(query, keys) * self.scale  # [B,1,Q]*[B,K,T] -> [B,1,T]
        energy = self.softmax(energy)

        if mask is not None:
            # Zero out masked positions, then renormalize so the surviving
            # weights sum to 1 again. Tensor.div is out-of-place, so its
            # result must be assigned back.
            energy = energy * mask
            # Clamp the denominator so a fully masked row gives zeros
            # instead of 0/0 = NaN.
            denom = energy.sum(2, keepdim=True).clamp_min(
                torch.finfo(energy.dtype).tiny
            )
            energy = energy / denom

        # Weight the values.
        values = values.transpose(0, 1)  # [T,B,V] -> [B,T,V]
        combo = torch.bmm(energy, values).squeeze(1)  # [B,1,T]*[B,T,V] -> [B,V]
        return (combo, energy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
@jaypatravali, ah, yes, you're right. I misinterpreted the line in the PyTorch softmax docs that reads "so every slice along dim will sum to 1". Thanks!