# TV- and JS- barycentres in PyTorch
# GitHub gist davidad/b09e03e04f523cf1526d9e146a52842a, created August 15, 2024.
# Jensen-Shannon Barycentre in PyTorch
# Adapted by davidad, 2024-08,
# from `jensen-shannon-centroid` by Dustin Wright, 2023,
# which was based on Frank Nielsen, 2020, "On a Generalization of the Jensen–Shannon Divergence and the Jensen–Shannon Centroid", Entropy 22.
import torch
from typing import List, Union
import logging
def delta_f(theta: torch.Tensor) -> torch.Tensor:
    """
    Gradient of the negentropy w.r.t. the natural parameter theta
    (equation 96 from Nielsen, 2020).

    Computes ``log(theta_j / (1 - sum_j theta_j))`` along the last dimension.

    :param theta: torch.Tensor whose last dimension is summed to obtain the
        remaining probability mass (in this file: the first K-1 probabilities
        of a K-way categorical distribution).
    :return: torch.Tensor of the same shape as ``theta``.
    """
    # Mass left for the dropped last category. Clamped so the ratio below
    # stays finite when the entries of theta sum to (almost) one or more.
    norm = torch.clamp(1 - theta.sum(dim=-1, keepdim=True), min=1e-8)
    return torch.log(theta / norm)
def delta_f_inv(eta: torch.Tensor) -> torch.Tensor:
    """
    Inverse gradient of the negentropy (equation 97 from Nielsen, 2020).

    Computes ``exp(eta_j) / (1 + sum_j exp(eta_j))`` along the last dimension.

    :param eta: torch.Tensor of values in the gradient (log-ratio) domain.
    :return: torch.Tensor of the same shape as ``eta``.
    """
    # exp(eta_j) / (1 + sum exp(eta)) is exactly the softmax over [eta, 0]
    # with the trailing entry dropped. Going through softmax avoids the
    # inf / inf = nan that a direct exp() produces for large eta.
    padded = torch.nn.functional.pad(eta, (0, 1))
    return torch.softmax(padded, dim=-1)[..., :-1]
def barycentre_jensen_shannon(
        logits: torch.Tensor,
        T: int = 1000,
        eps: float = 1e-10) -> torch.Tensor:
    """
    Calculate the Jensen-Shannon barycentre of a set of categorical distributions.

    The logits are turned into probabilities with a softmax over the last
    dimension, and the centroid is found with the concave-convex (CCCP)
    fixed-point iteration of Nielsen (2020):
    ``theta <- delta_f_inv(mean_i delta_f((theta_i + theta) / 2))``.

    :param logits: A torch.Tensor of size NxMxK
    :param T: The maximum number of optimization steps.
    :param eps: Minimum difference between distributions at t and t + 1 needed for convergence
    :return: A torch.Tensor of size NxK holding the log-probabilities of the
        Jensen-Shannon centroid between the M distributions for each of the N ensembles.
    :raises AssertionError: if ``logits`` is not 3D, or if the iteration has
        not converged after ``T`` steps.
    """
    assert logits.dim() == 3, f"Shape of distributions should be 3D, found {logits.dim()}"
    # Convert to natural parameters: the first K-1 probabilities of each distribution.
    prob_dists = torch.nn.functional.softmax(logits, dim=-1)
    natural_dists = prob_dists[:, :, :-1]
    # Start from the arithmetic mean; keep a singleton M axis so theta
    # broadcasts against natural_dists.
    theta = natural_dists.mean(1).unsqueeze(1)
    # An absolute tolerance below the rounding noise of the working dtype
    # (e.g. the default 1e-10 in float32) could only be met by exact equality,
    # so never ask for more precision than the dtype can represent.
    tol = max(eps, torch.finfo(theta.dtype).eps * theta.numel())
    converged = False
    for t in range(T):
        # CCCP step: average the negentropy gradient over the midpoints
        # between the current estimate and every input distribution.
        dfs = delta_f((theta + natural_dists) / 2).mean(1)
        theta_new = delta_f_inv(dfs).unsqueeze(1)
        step = torch.abs(theta - theta_new).sum()
        theta = theta_new
        # Stop if there's no significant difference
        if step < tol:
            logging.debug("Jensen-Shannon centroid converged after %d iterations", t)
            converged = True
            break
    assert converged, f"Couldn't converge after {T} iterations!"
    # Re-attach the probability of the dropped last category.
    prob_dist = torch.cat([theta, 1 - theta.sum(dim=-1, keepdim=True)], dim=-1).squeeze(1)
    return torch.log(prob_dist)
# Example usage: N=2 ensembles, each with M=3 distributions over K=3 categories.
distributions = torch.Tensor([
    [[0.17, 0.8, 0.03],
     [0.21, 0.75, 0.04],
     [0.01, 0.01, 0.98]],
    [[0.2, 0.79, 0.01],
     [0.24, 0.75, 0.01],
     [0.49, 0.5, 0.01]],
])
# The barycentre function takes logits; the log of a normalised distribution
# is a valid set of logits for it.
logits = torch.log(distributions)
result = barycentre_jensen_shannon(logits)
# Total Variation Barycentre in PyTorch
# Written by davidad, 2024-08
# with help from ChatGPT 4
import torch
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
def barycentre_total_variation(logits: torch.Tensor) -> torch.Tensor:
    """
    Calculate the total-variation barycentre of a set of categorical distributions.

    The logits are turned into probabilities with a softmax over the last
    dimension. For every batch entry, a distribution Q is found that
    minimises the summed absolute difference ``sum |P - Q|`` to its M input
    distributions, by solving a linear program through a CvxpyLayer.

    :param logits: A torch.Tensor of size NxMxK
    :return: A torch.Tensor of size NxK holding the log-probabilities of the
        barycentre for each of the N batch entries.
    :raises AssertionError: if ``logits`` is not 3D.
    """
    assert logits.dim() == 3, f"Shape of distributions should be 3D, found {logits.dim()}"
    distributions = torch.nn.functional.softmax(logits, dim=-1)
    N, M, K = distributions.shape  # N = Batch size, M = number of predictors, K = support size

    # Reshape distributions to 2D: combine N and M into one dimension
    dist_reshaped = distributions.reshape(N * M, K)

    # Variables and parameters for CVXPY
    Q = cp.Variable((N, K))      # one barycentre per batch entry
    z = cp.Variable((N * M, K))  # elementwise upper bounds on |P - Q|
    P = cp.Parameter((N * M, K))

    # Constraints
    constraints = [
        cp.sum(Q, axis=1) == 1,  # Each row of Q must sum to 1
        Q >= 0,                  # Non-negativity of the probabilities
    ]

    # Add constraints for absolute values using a loop over combined batch and distribution dimension
    for i in range(N * M):
        n = i // M  # Map back to the batch index
        constraints += [
            z[i, :] >= P[i, :] - Q[n, :],
            z[i, :] >= Q[n, :] - P[i, :],
        ]

    # Objective: minimize the sum of z
    objective = cp.Minimize(cp.sum(z))
    problem = cp.Problem(objective, constraints)

    # Create a CVXPY layer
    cvxpylayer = CvxpyLayer(problem, parameters=[P], variables=[Q])

    # Solve the problem
    Q_sol, = cvxpylayer(dist_reshaped)

    # A numerical solver may return entries that are zero or slightly
    # negative; clamp them so the logarithm cannot produce NaN.
    Q_sol = torch.clamp(Q_sol, min=torch.finfo(Q_sol.dtype).tiny)
    return torch.log(Q_sol)
# Example usage: N=2 batch entries, each with M=3 distributions over K=3 categories.
distributions = torch.Tensor([
    [[0.17, 0.8, 0.03],
     [0.21, 0.75, 0.04],
     [0.01, 0.01, 0.98]],
    [[0.2, 0.79, 0.01],
     [0.24, 0.75, 0.01],
     [0.49, 0.5, 0.01]],
])
logits = torch.log(distributions)
result = barycentre_total_variation(logits)
# NOTE(review): the rows of the tensor quoted below sum to 1, so it looks like
# the probabilities torch.exp(result) rather than the log-probabilities that
# are returned -- confirm against a real run.
# tensor([[0.1836, 0.7631, 0.0533],
#         [0.2397, 0.7503, 0.0100]])
# (GitHub page footer from the scraped gist page; not part of the code.)