Last active
October 17, 2024 01:31
-
-
Save Flunzmas/6e359b118b0730ab403753dcc2a447df to your computer and use it in GitHub Desktop.
Differentiable 2-Wasserstein Distance in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
import torch.linalg as linalg | |
def calculate_2_wasserstein_dist(X, Y):
    '''
    Calculates the squared 2-Wasserstein distance between Gaussian fits of two sample batches.

    The general formula is given by: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2]

    For multivariate gaussian distributed inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y),
    this reduces to: d = |mu_X - mu_Y|^2 + Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))

    Fast method implemented according to following paper: https://arxiv.org/pdf/2009.14075.pdf
    (the trace of the matrix square root is obtained from a [b, b] eigenvalue problem
    instead of an [n, n] one, so no [n, n] matrix is ever materialised).

    Input shape: [b, n] (e.g. batch_size x num_features), identical for X and Y
    Output shape: scalar (float32); differentiable w.r.t. X and Y

    Raises:
        ValueError: if X and Y do not have the same shape.
    '''
    if X.shape != Y.shape:
        raise ValueError("Expecting equal shapes for X and Y!")

    # the linear algebra ops will need some extra precision -> convert to double
    X, Y = X.transpose(0, 1).double(), Y.transpose(0, 1).double()  # [n, b]
    mu_X = torch.mean(X, dim=1, keepdim=True)  # [n, 1]
    mu_Y = torch.mean(Y, dim=1, keepdim=True)  # [n, 1]

    # unpacking into exactly two names also enforces that the inputs are 2-D
    _, b = X.shape
    # unbiased covariance normalisation; fall back to 1.0 for a single sample
    fact = 1.0 if b < 2 else 1.0 / (b - 1)

    # scaled, centred data: cov_X = C_X @ C_X^T, i.e. C_X is a "root" of the covariance
    C_X = (X - mu_X) * math.sqrt(fact)  # [n, b]
    C_Y = (Y - mu_Y) * math.sqrt(fact)  # [n, b]

    # Tr(cov_X + cov_Y) equals the sum of squared Frobenius norms of the roots,
    # so the [n, n] covariance matrices never have to be built.
    trace_cov = torch.sum(C_X * C_X) + torch.sum(C_Y * C_Y)  # scalar

    # calculate Tr((cov_X * cov_Y)^(1/2)) with the method proposed in
    # https://arxiv.org/pdf/2009.14075.pdf: the non-zero eigenvalues of cov_X * cov_Y
    # coincide with those of M = (C_X^T C_Y)(C_X^T C_Y)^T.
    M_l = torch.matmul(C_X.t(), C_Y)  # [b, b]
    M = torch.matmul(M_l, M_l.t())  # [b, b], symmetric positive semi-definite

    # M is symmetric, so the Hermitian solver applies: real eigenvalues, no complex
    # intermediate, cheaper than the general solver.
    S = linalg.eigvalsh(M)
    # rounding can push zero eigenvalues slightly below 0 -> clamp them;
    # add small constant to avoid infinite gradients from sqrt(0)
    sq_tr_cov = (S.clamp(min=0.0) + 1e-15).sqrt().sum()

    # plug the sqrt_trace_component into Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))
    trace_term = trace_cov - 2.0 * sq_tr_cov  # scalar

    # |mu_X - mu_Y|^2
    diff = mu_X - mu_Y  # [n, 1]
    mean_term = torch.sum(diff * diff)  # scalar

    # put it together
    return (trace_term + mean_term).float()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment