SelfAttention implementation in PyTorch
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter, init


class SelfAttention(nn.Module):
    def __init__(self, attention_size, batch_first=False, non_linearity="tanh"):
        super(SelfAttention, self).__init__()

        self.batch_first = batch_first
        # one learnable weight per hidden dimension, shared across all timesteps
        self.attention_weights = Parameter(torch.FloatTensor(attention_size))
        self.softmax = nn.Softmax(dim=-1)

        if non_linearity == "relu":
            self.non_linearity = nn.ReLU()
        else:
            self.non_linearity = nn.Tanh()

        init.uniform_(self.attention_weights.data, -0.005, 0.005)

    def get_mask(self, attentions, lengths):
        """
        Construct a mask for the padded timesteps, based on the sequence lengths.
        """
        max_len = max(lengths.data)
        mask = Variable(torch.ones(attentions.size())).detach()

        if attentions.data.is_cuda:
            mask = mask.cuda()

        # zero out the positions after each sequence's true length
        for i, l in enumerate(lengths.data):
            if l < max_len:
                mask[i, l:] = 0
        return mask

    def forward(self, inputs, lengths):
        ##################################################################
        # STEP 1 - perform dot product
        # of the attention vector and each hidden state
        ##################################################################

        # inputs is a 3D Tensor: batch, len, hidden_size
        # scores is a 2D Tensor: batch, len
        # NOTE: forward() assumes batch-first inputs; batch_first is stored
        # in __init__ but not used here
        scores = self.non_linearity(inputs.matmul(self.attention_weights))
        scores = self.softmax(scores)

        ##################################################################
        # Step 2 - Masking
        ##################################################################

        # construct a mask, based on the sentence lengths
        mask = self.get_mask(scores, lengths)

        # apply the mask - zero out the scores of the padded timesteps
        masked_scores = scores * mask

        # re-normalize the masked scores so each row sums to 1 again
        _sums = masked_scores.sum(-1, keepdim=True)  # sums per row
        scores = masked_scores.div(_sums)  # divide by row sum

        ##################################################################
        # Step 3 - Weighted sum of hidden states, by the attention scores
        ##################################################################

        # multiply each hidden state by its attention score
        weighted = torch.mul(inputs, scores.unsqueeze(-1).expand_as(inputs))

        # sum over the timesteps: (batch, len, hidden_size) -> (batch, hidden_size)
        representations = weighted.sum(1)

        return representations, scores
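A minimal usage sketch, not part of the original gist: the module is typically applied to the padded outputs of a recurrent encoder, with attention_size equal to the encoder's hidden size. The encoder layer, variable names, and toy dimensions below are illustrative assumptions.

import torch
import torch.nn as nn

hidden_size = 128
encoder = nn.LSTM(input_size=300, hidden_size=hidden_size, batch_first=True)
attention = SelfAttention(attention_size=hidden_size, batch_first=True)

# a toy batch of 4 padded sequences of 300-dim embeddings, max length 20
embeddings = torch.randn(4, 20, 300)
lengths = torch.tensor([20, 17, 12, 5])

outputs, _ = encoder(embeddings)                      # (batch, len, hidden_size)
representations, scores = attention(outputs, lengths)

print(representations.shape)                          # torch.Size([4, 128])
print(scores.shape)                                   # torch.Size([4, 20])

The returned representations are the attention-weighted sentence vectors, and scores are the per-timestep attention weights (zeroed and re-normalized over the padded positions), which can be inspected or visualized.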