@crowsonkb · Last active April 11, 2024
Mixture of Softmaxes
"""Mixture of Softmaxes"""
import torch
from torch.nn import functional as F
class MixtureOfSoftmaxes(torch.autograd.Function):
@staticmethod
def forward(ctx, x, p):
with torch.cuda.amp.autocast(enabled=False):
x = F.log_softmax(x, dim=-1)
p = F.log_softmax(p, dim=-1)
out = torch.logsumexp(x + p[..., None], dim=-2)
ctx.save_for_backward(x, p, out)
return out
@staticmethod
def backward(ctx, grad_output):
with torch.cuda.amp.autocast(enabled=False):
x, p, out = ctx.saved_tensors
grad_x = torch.exp(x + p[..., None] - out[..., None, :]) * grad_output[..., None, :]
grad_p = torch.sum(grad_x, dim=-1)
grad_x -= torch.exp(x) * torch.sum(grad_x, dim=-1, keepdim=True)
grad_p -= torch.exp(p) * torch.sum(grad_p, dim=-1, keepdim=True)
return grad_x, grad_p
def mixture_of_softmaxes(x, p):
"""Returns log probabilities of a mixture of k softmaxes.
Args:
x: The input logits, shape (..., k, n_classes).
p: The mixture logits, shape (..., k).
Returns:
The log probabilities, shape (..., n_classes).
"""
return MixtureOfSoftmaxes.apply(x, p.to(x.dtype))
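
A quick way to sanity-check the handwritten backward pass is torch.autograd.gradcheck, which compares the analytic gradients against finite differences in double precision. A minimal sketch, not part of the original gist; the shapes below are arbitrary, chosen only to exercise the function:

# Sanity check for the custom backward pass (illustrative; shapes are arbitrary).
if __name__ == "__main__":
    x = torch.randn(2, 4, 10, dtype=torch.float64, requires_grad=True)
    p = torch.randn(2, 4, dtype=torch.float64, requires_grad=True)

    log_probs = mixture_of_softmaxes(x, p)
    print(log_probs.shape)  # torch.Size([2, 10])
    # Each output row is a valid log distribution: probabilities sum to 1.
    print(torch.exp(log_probs).sum(dim=-1))

    # Agreement with a plain autograd reference implementation.
    ref = torch.logsumexp(
        F.log_softmax(x, dim=-1) + F.log_softmax(p, dim=-1)[..., None], dim=-2
    )
    print(torch.allclose(log_probs, ref))

    # Finite-difference check of the handwritten backward (requires float64).
    print(torch.autograd.gradcheck(mixture_of_softmaxes, (x, p)))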