TV- and JS- barycentres in PyTorch
# Jensen-Shannon Barycentre in PyTorch
# Adapted by davidad, 2024-08,
# from `jensen-shannon-centroid` by Dustin Wright, 2023,
# which was based on Frank Nielsen, 2020, "On a Generalization of the Jensen–Shannon Divergence and the Jensen–Shannon Centroid", Entropy 22.
import torch
import logging
def delta_f(theta: torch.Tensor) -> torch.Tensor:
    """
    Gradient of the negentropy w.r.t. the natural parameter \\theta (equation 96 from Nielsen, 2020).
    :param theta: torch.Tensor
    :return: torch.Tensor
    """
    # Clamp to avoid division by zero when the first K-1 probabilities sum to ~1
    norm = torch.clamp(1 - theta.sum(dim=-1, keepdim=True), min=1e-8)
    return torch.log(theta / norm)
def delta_f_inv(eta: torch.Tensor) -> torch.Tensor:
    """
    Inverse gradient of the negentropy (equation 97 from Nielsen, 2020).
    :param eta: torch.Tensor
    :return: torch.Tensor
    """
    # Softmax-style map back to the first K-1 probabilities
    norm = 1 + torch.exp(eta).sum(dim=-1, keepdim=True)
    return torch.exp(eta) / norm
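
# A minimal round-trip sanity check (not part of the original gist):
# delta_f_inv should invert delta_f on the first K-1 coordinates of any
# categorical distribution, up to floating-point error.
_p = torch.nn.functional.softmax(torch.randn(4, 7), dim=-1)
_theta = _p[:, :-1]
assert torch.allclose(delta_f_inv(delta_f(_theta)), _theta, atol=1e-5)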
def barycentre_jensen_shannon(
        logits: torch.Tensor,
        T: int = 1000,
        eps: float = 1e-10) -> torch.Tensor:
    """
    Calculate the Jensen-Shannon barycentre of a set of categorical distributions.
    :param logits: A torch.Tensor of size NxMxK
    :param T: The maximum number of optimization steps.
    :param eps: Convergence threshold: stop once the iterates at steps t and t + 1 differ by less than this.
    :return: A torch.Tensor of size NxK holding the log-probabilities of the Jensen-Shannon centroid of the M distributions in each of the N ensembles.
    """
    assert logits.dim() == 3, f"Shape of distributions should be 3D, found {logits.dim()}"
    N, M, K = logits.shape
    # Convert to natural parameters (drop the last category; it is determined by the rest)
    prob_dists = torch.nn.functional.softmax(logits, dim=-1)
    natural_dists = prob_dists[:, :, :-1]
    # Initialize theta at the arithmetic mean of the M distributions
    theta = natural_dists.mean(1).unsqueeze(1)
    converged = False
    for t in range(T):
        # Fixed-point update: average the gradients of the current iterate and
        # the M input distributions, then map back through the inverse gradient
        dfs = delta_f(torch.cat([theta, natural_dists], dim=1)).mean(1)
        theta_new = delta_f_inv(dfs).unsqueeze(1)
        # Stop if there's no significant difference
        if torch.abs(theta - theta_new).sum() < eps:
            logging.debug(f"Jensen-Shannon centroid converged after {t} iterations")
            converged = True
            break
        theta = theta_new
    assert converged, f"Couldn't converge after {T} iterations!"
    # Restore the last category and return log-probabilities
    prob_dist = torch.cat([theta, 1 - theta.sum(dim=-1, keepdim=True)], dim=-1).squeeze(1)
    return torch.log(prob_dist)
# Example usage
distributions = torch.Tensor([
    [[0.17, 0.8, 0.03],
     [0.21, 0.75, 0.04],
     [0.01, 0.01, 0.98]],
    [[0.2, 0.79, 0.01],
     [0.24, 0.75, 0.01],
     [0.49, 0.5, 0.01]]
])
logits = torch.log(distributions)
result = barycentre_jensen_shannon(logits)
print(torch.nn.functional.softmax(result, dim=-1))
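
# A quick check (not in the original gist) that the result behaves like a
# barycentre: its mean Jensen-Shannon divergence to the M inputs should not
# exceed that of the naive arithmetic mean. `jsd` is a hypothetical helper.
def jsd(p, q):
    def entropy(x):
        return -(x * torch.log(x)).sum(dim=-1)
    m = (p + q) / 2
    return entropy(m) - (entropy(p) + entropy(q)) / 2

centroid = torch.nn.functional.softmax(result, dim=-1).unsqueeze(1)
naive = distributions.mean(dim=1, keepdim=True)
print(jsd(centroid, distributions).mean(), jsd(naive, distributions).mean())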
# Total Variation Barycentre in PyTorch
# Written by davidad, 2024-08
# with help from ChatGPT 4
import torch
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
def barycentre_total_variation(logits):
    """
    Calculate the total-variation barycentre of a set of categorical distributions.
    :param logits: A torch.Tensor of size NxMxK
    :return: A torch.Tensor of size NxK holding the log-probabilities of the barycentre.
    """
    assert logits.dim() == 3, f"Shape of distributions should be 3D, found {logits.dim()}"
    distributions = torch.nn.functional.softmax(logits, dim=-1)
    N, M, K = distributions.shape  # N = batch size, M = number of predictors, K = support size
    # Reshape distributions to 2D: combine N and M into one dimension
    dist_reshaped = distributions.reshape(N * M, K)
    # Variables and parameters for CVXPY
    Q = cp.Variable((N, K))       # the N barycentres
    z = cp.Variable((N * M, K))   # epigraph variables bounding |P - Q|
    P = cp.Parameter((N * M, K))  # the input distributions
    # Constraints
    constraints = [
        cp.sum(Q, axis=1) == 1,  # Each row of Q must sum to 1
        Q >= 0,                  # Non-negativity of the probabilities
    ]
    # Epigraph constraints for the absolute values, looping over the combined
    # batch and predictor dimension: at the optimum, z[i, k] = |P[i, k] - Q[n, k]|
    for i in range(N * M):
        n = i // M  # Map back to the batch index
        constraints += [
            z[i, :] >= P[i, :] - Q[n, :],
            z[i, :] >= Q[n, :] - P[i, :]
        ]
    # Objective: minimize the sum of z, i.e. the summed L1 distances
    # (twice the total-variation distances) from each Q to its M inputs
    objective = cp.Minimize(cp.sum(z))
    problem = cp.Problem(objective, constraints)
    # Create a CVXPY layer (differentiable w.r.t. the parameter P)
    cvxpylayer = CvxpyLayer(problem, parameters=[P], variables=[Q])
    # Solve the problem
    Q_sol, = cvxpylayer(dist_reshaped)
    return torch.log(Q_sol)
# Example usage
distributions = torch.Tensor([
    [[0.17, 0.8, 0.03],
     [0.21, 0.75, 0.04],
     [0.01, 0.01, 0.98]],
    [[0.2, 0.79, 0.01],
     [0.24, 0.75, 0.01],
     [0.49, 0.5, 0.01]]
])
logits = torch.log(distributions)
result = barycentre_total_variation(logits)
print(torch.nn.functional.softmax(result, dim=-1))
# tensor([[0.1836, 0.7631, 0.0533],
#         [0.2397, 0.7503, 0.0100]])
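
# Since the barycentre is computed through a CvxpyLayer, gradients flow back
# through the solver to the input logits. A minimal sketch (not in the
# original gist); the gradients are those of the LP solution map, which is
# only piecewise differentiable:
logits = torch.log(distributions).requires_grad_(True)
result = barycentre_total_variation(logits)
result.sum().backward()
print(logits.grad.shape)  # expected: torch.Size([2, 3, 3])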