@kaniblu
Last active September 1, 2021 13:19
Masked Softmax in PyTorch
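import torch
import torch.nn as nn


class MaskedSoftmax(nn.Module):
    def __init__(self):
        super(MaskedSoftmax, self).__init__()
        self.softmax = nn.Softmax(1)

    def forward(self, x, mask=None):
        """
        Performs masked softmax, as simply masking post-softmax can be
        inaccurate.
        :param x: [batch_size, num_items] logits
        :param mask: [batch_size, num_items], nonzero for items to keep,
            0 for items to mask out
        :return: [batch_size, num_items] probabilities; masked items get 0
        """
        if mask is None:
            return self.softmax(x)
        # Set masked logits to -inf so they are ignored by the max below
        # and contribute exp(-inf) = 0 to the normalizer.
        x_masked = x.masked_fill(mask == 0, float("-inf"))
        # Subtract the row-wise max over unmasked items for numerical stability.
        # Shifting x_masked (not x) keeps masked entries at exp(-inf) = 0, so a
        # large masked logit cannot overflow to inf and produce NaN.
        x_max = x_masked.max(1, keepdim=True)[0]
        x_exp = (x_masked - x_max).exp()
        return x_exp / x_exp.sum(1, keepdim=True)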
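The docstring's point that masking after the softmax can be inaccurate is easiest to see on a concrete row. The sketch below uses two hypothetical helpers (`masked_softmax` and `mask_after_softmax`, not part of the gist) to compare masking the logits before softmax against masking the probabilities afterwards and renormalizing. The two agree mathematically, but when a masked logit dominates the row, the kept probabilities underflow to 0 in float32 before the mask is applied, and the renormalization becomes 0 / 0.

```python
import torch


def masked_softmax(x, mask, dim=-1):
    # Mask *before* softmax: masked logits become -inf, so they get
    # probability exactly 0 and the kept items are normalized among themselves.
    return torch.softmax(x.masked_fill(mask == 0, float("-inf")), dim=dim)


def mask_after_softmax(x, mask, dim=-1):
    # Mask *after* softmax, then renormalize the kept probabilities.
    p = torch.softmax(x, dim=dim) * mask
    return p / p.sum(dim=dim, keepdim=True)


# One row where the masked item's logit dominates.
x = torch.tensor([[0.0, 0.0, 200.0]])
mask = torch.tensor([[1, 1, 0]])

print(masked_softmax(x, mask))      # kept items get 0.5 each, masked item gets 0
print(mask_after_softmax(x, mask))  # all NaN: exp(-200) underflows, then 0 / 0
```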
@zimonitrome
Hello, I found this post via Google.

I made a smaller, functional version that works for any tensor shape:

import torch

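def masked_softmax(x, mask, **kwargs):
    # mask should have the same shape as x (boolean indexing does not broadcast).
    x_masked = x.clone()  # copy so the caller's tensor is left unchanged
    x_masked[mask == 0] = -float("inf")  # exp(-inf) = 0 after softmax

    # kwargs are passed to torch.softmax, so supply dim=... when calling.
    return torch.softmax(x_masked, **kwargs)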

If anyone else ever needs it :)
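A usage sketch for the functional version, on a 3-D tensor such as attention scores of shape [batch, queries, keys]; the shapes and the `safe_masked_softmax` name are illustrative assumptions. Because `x_masked[mask == 0]` uses boolean indexing, the mask here has the same shape as `x`. One caveat: if every entry along `dim` is masked, softmax sees only -inf and returns NaN for that row, so the hypothetical wrapper below returns zeros for such rows instead.

```python
import torch


def masked_softmax(x, mask, **kwargs):
    # Functional version from the comment above; pass dim=... via kwargs.
    x_masked = x.clone()
    x_masked[mask == 0] = -float("inf")
    return torch.softmax(x_masked, **kwargs)


def safe_masked_softmax(x, mask, dim=-1):
    # Hypothetical wrapper: rows with no unmasked item come back as all
    # zeros instead of NaN.
    out = masked_softmax(x, mask, dim=dim)
    has_kept_item = (mask != 0).any(dim=dim, keepdim=True)
    return torch.where(has_kept_item, out, torch.zeros_like(out))


# Attention-style scores: [batch, queries, keys].
scores = torch.randn(2, 3, 4)
mask = torch.ones(2, 3, 4, dtype=torch.long)
mask[0, :, 2:] = 0  # batch 0: the last two keys are padding
mask[1, 2, :] = 0   # batch 1, query 2: every key is masked

probs = safe_masked_softmax(scores, mask, dim=-1)
print(probs.sum(dim=-1))  # about 1 per row, except 0 for batch 1, query 2
```

Selecting with `torch.where` on the mask, rather than replacing every NaN in the output, only zeroes rows that are NaN because they are fully masked, so NaNs from other sources still surface.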
