@Flunzmas
Flunzmas / calc_2_wasserstein_dist.py
Last active April 1, 2025 11:11
Differentiable 2-Wasserstein Distance in PyTorch
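For Gaussian inputs, the squared 2-Wasserstein distance has the closed form |mu_X - mu_Y|^2 + Tr(cov_X + cov_Y - 2 (cov_Y^{1/2} cov_X cov_Y^{1/2})^{1/2}). Every operation in it can be differentiated in PyTorch. The sketch below is a minimal illustration of that formula and is not the gist's own implementation. It makes these assumptions:

- X and Y are float tensors of shape [b, n] (samples x features) with b >= 2.
- A Gaussian is fitted to each batch using the sample mean and the unbiased covariance.
- Matrix square roots are computed by eigendecomposition with `torch.linalg.eigh`.

The helper names `sqrtm_psd` and `gaussian_w2_squared` are hypothetical.

```python
import torch


def sqrtm_psd(mat: torch.Tensor) -> torch.Tensor:
    """Matrix square root of a symmetric PSD matrix (hypothetical helper)."""
    eigvals, eigvecs = torch.linalg.eigh(mat)
    # Round-off can push eigenvalues slightly below zero; clamp before the sqrt.
    eigvals = eigvals.clamp(min=0.0)
    # Rebuild V diag(sqrt(lambda)) V^T; broadcasting scales column j of V by sqrt(lambda_j).
    return (eigvecs * eigvals.sqrt()) @ eigvecs.T


def gaussian_w2_squared(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Squared 2-Wasserstein distance between Gaussians fitted to samples X and Y.

    Assumes X and Y have shape [b, n] (samples x features), with b >= 2.
    Hypothetical sketch, not the gist's own implementation.
    """
    mu_x, mu_y = X.mean(dim=0), Y.mean(dim=0)

    # Unbiased sample covariances, always shaped [n, n] (even when n == 1).
    xc, yc = X - mu_x, Y - mu_y
    cov_x = xc.T @ xc / (X.shape[0] - 1)
    cov_y = yc.T @ yc / (Y.shape[0] - 1)

    # Mean component: |mu_X - mu_Y|^2
    mean_term = torch.sum((mu_x - mu_y) ** 2)

    # Covariance component: Tr(cov_X + cov_Y - 2 (cov_Y^{1/2} cov_X cov_Y^{1/2})^{1/2})
    sqrt_cov_y = sqrtm_psd(cov_y)
    inner = sqrt_cov_y @ cov_x @ sqrt_cov_y
    inner = 0.5 * (inner + inner.T)  # re-symmetrize to absorb round-off
    cov_term = torch.trace(cov_x + cov_y - 2.0 * sqrtm_psd(inner))

    return mean_term + cov_term


# Usage: two batches of 3-dimensional samples.
torch.manual_seed(0)
X = torch.randn(200, 3, dtype=torch.float64)
Y = 2.0 * torch.randn(150, 3, dtype=torch.float64) + 1.0
print(gaussian_w2_squared(X, Y).item())
```

Eigendecomposition keeps the example short. However, its gradients can be unstable when eigenvalues are repeated or near zero, for example with rank-deficient covariances when b <= n. In that regime, iterative square-root schemes such as Newton-Schulz are a common alternative.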
import math
import torch
import torch.linalg as linalg
def calculate_2_wasserstein_dist(X, Y):
    '''
    Calculates the two components of the 2-Wasserstein metric.
    The general formula (for the squared distance) is: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2],
    where the minimum is taken over all couplings (joint distributions) of X ~ P_X and Y ~ P_Y.
    For multivariate Gaussian inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y),