Last active
April 11, 2024 21:23
-
-
Save crowsonkb/feb45795bb8e86d665db25570d317726 to your computer and use it in GitHub Desktop.
Mixture of Softmaxes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Mixture of Softmaxes""" | |
import torch | |
from torch.nn import functional as F | |
class MixtureOfSoftmaxes(torch.autograd.Function):
    """Log-probabilities of a mixture of k softmaxes, with a hand-written backward.

    Forward computes, in log space,
    ``logsumexp(log_softmax(x, -1) + log_softmax(p, -1)[..., None], dim=-2)``,
    i.e. ``log(sum_i softmax(p)_i * softmax(x_i))``.

    The backward pass is first-order only (``once_differentiable``): the
    normalized intermediates saved in ``forward`` are not connected to the
    graph, so double backward through this function would otherwise be
    silently wrong.
    """

    @staticmethod
    def forward(ctx, x, p):
        """Compute mixture log-probabilities.

        Args:
            ctx: Autograd context; the normalized inputs and the output are
                saved on it for ``backward``.
            x: Component logits, shape (..., k, n_classes).
            p: Mixture logits, shape (..., k).

        Returns:
            Mixture log-probabilities, shape (..., n_classes).
        """
        # Keep CUDA autocast from re-casting the log-space ops.
        # torch.autocast("cuda", ...) is the non-deprecated equivalent of
        # torch.cuda.amp.autocast(...).
        with torch.autocast("cuda", enabled=False):
            x = F.log_softmax(x, dim=-1)
            p = F.log_softmax(p, dim=-1)
            out = torch.logsumexp(x + p[..., None], dim=-2)
            ctx.save_for_backward(x, p, out)
            return out

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Return gradients w.r.t. the raw logits ``x`` and ``p``.

        Args:
            ctx: Autograd context holding ``(log_softmax(x), log_softmax(p), out)``.
            grad_output: Gradient w.r.t. the output, shape (..., n_classes).

        Returns:
            ``(grad_x, grad_p)``; an entry is ``None`` when that input does
            not require a gradient.
        """
        need_x, need_p = ctx.needs_input_grad[:2]
        if not (need_x or need_p):
            return None, None
        with torch.autocast("cuda", enabled=False):
            x, p, out = ctx.saved_tensors
            # d out_j / d log_softmax(x)_ij is the posterior weight
            # exp(x_ij + p_i - out_j); scale it by the incoming gradient.
            grad_x = torch.exp(x + p[..., None] - out[..., None, :]) * grad_output[..., None, :]
            grad_p = None
            if need_p:
                # Gradient w.r.t. log_softmax(p): sum over classes.
                grad_p = torch.sum(grad_x, dim=-1)
                # Chain rule through log_softmax: g - softmax(p) * sum(g).
                grad_p -= torch.exp(p) * torch.sum(grad_p, dim=-1, keepdim=True)
            if need_x:
                # Chain rule through log_softmax over classes (in place is safe:
                # grad_x is a fresh tensor and grad_p was already reduced from it).
                grad_x -= torch.exp(x) * torch.sum(grad_x, dim=-1, keepdim=True)
            else:
                grad_x = None
            return grad_x, grad_p
def mixture_of_softmaxes(x, p):
    """Compute log-probabilities of a mixture of k softmax distributions.

    The mixture logits are cast to the dtype of the component logits before
    being handed to ``MixtureOfSoftmaxes``.

    Args:
        x: Component logits, shape (..., k, n_classes).
        p: Mixture-weight logits, shape (..., k).

    Returns:
        Mixture log-probabilities, shape (..., n_classes).
    """
    weight_logits = p.to(x.dtype)
    return MixtureOfSoftmaxes.apply(x, weight_logits)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment