This code defines a PyTorch implementation of the Sparsemax activation function. Sparsemax is an alternative to softmax that produces sparse probability distributions: it is the Euclidean projection of the input onto the probability simplex, so some output entries can be exactly zero. The implementation is provided as a PyTorch nn.Module, making it easy to integrate into any architecture.
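For reference, sparsemax is defined in Martins & Astudillo (2016, https://arxiv.org/abs/1602.02068) as

    sparsemax(z) = argmin_{p ∈ Δ} ||p − z||²

where Δ is the probability simplex. It admits the closed form sparsemax(z)_i = max(z_i − τ(z), 0), where τ(z) = (sum of the k(z) largest entries of z − 1) / k(z) and k(z) is the size of the support. The code below computes k(z) and τ(z) in a vectorized way.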
import torch
import torch.nn as nn


class Sparsemax(nn.Module):
    def __init__(self, dim=-1):
        super(Sparsemax, self).__init__()
        self.dim = dim

    def forward(self, x):
        # Move the dimension to apply Sparsemax to the last position
        x = x.transpose(self.dim, -1)

        # Calculate the cumulative sum of the sorted input
        z, _ = torch.sort(x, dim=-1, descending=True)
        cumsums = torch.cumsum(z, dim=-1)

        # Project onto the simplex; see details in https://arxiv.org/pdf/1602.02068.pdf
        K = torch.arange(1, x.shape[-1] + 1, device=x.device)
        K = K.repeat(*x.shape[:-1], 1)
        support = 1 + K * z - cumsums > 0
        k_z = (K * support).max(dim=-1, keepdim=True).values

        # Compute the threshold tau(z) and apply it to the input;
        # (k_z - 1) corrects for the 1-indexing used in the paper
        cumsums_element = torch.gather(cumsums, dim=-1, index=(k_z - 1))
        thresholds = (cumsums_element - 1) / k_z
        output = torch.clamp(x - thresholds, min=0)

        # Transpose the dimensions back
        output = output.transpose(self.dim, -1)
        return output
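A minimal usage sketch (the logits here are arbitrary values, chosen only to produce a partially sparse output; the expected values follow from the threshold formula above):

sparsemax = Sparsemax(dim=-1)
logits = torch.tensor([[1.0, 2.0, 2.5]])
probs = sparsemax(logits)
print(probs)                          # tensor([[0.0000, 0.2500, 0.7500]])
print(probs.sum(dim=-1))              # tensor([1.]); each row sums to 1
print(torch.softmax(logits, dim=-1))  # strictly positive everywhere, for contrast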