Convolutional Wasserstein Distance & Barycenter implementation in PyTorch
# CC0, 2021 Jeremy Cowles, no rights reserved.
#
# The following is a 1-dimensional implementation of two core algorithms from the paper
#
# Convolutional Wasserstein Distances: Efficient Optimal Transportation on Geometric Domains
# https://people.csail.mit.edu/jsolomon/assets/convolutional_w2.compressed.pdf
#
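# Sketch of the math the functions below implement (restated from the code
# itself; see the paper for the full derivation). Entropy-regularized
# transport replaces the exact heat kernel with a gaussian convolution H.
# With uniform area weights a = 1/n, the scaling vectors v, w are updated by
# Sinkhorn iterations
#
#     v <- mu0 / H(a * w)
#     w <- mu1 / H(a * v)
#
# and the regularized distance is read off as
#
#     W2^2(mu0, mu1) ~= gamma * sum(a * (mu0 * log(v) + mu1 * log(w)))
#
# where gamma = sigma^2 ties the kernel width to the regularization strength.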
import math
from typing import Tuple

import torch
import torch.nn.functional as F

# The following sets up a 1D gaussian, but can be extended to multiple dimensions
# by modifying _create_window().
def _gauss(window_size, sigma) -> torch.Tensor:
    """
    Create weights for a discrete gaussian kernel.
    """
    gauss = torch.tensor(
        [math.exp(-(x - 0.5 * (window_size - 1)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)],
        dtype=torch.float,
        requires_grad=False,
    )
    return gauss / gauss.sum()
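# Example (values rounded): the kernel is symmetric and normalized to sum
# to 1, so convolving with it conserves total mass:
#
#   _gauss(5, 1.0)  ->  tensor([0.0545, 0.2442, 0.4026, 0.2442, 0.0545])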
def _create_window(window_size, sigma=None) -> Tuple[torch.Tensor, float]:
    """
    Create 1D weights for a gaussian convolution, returning (window, sigma).
    If sigma is not provided, it will be computed from the window size.
    """
    if not sigma:
        # Compute sigma in terms of pixels.
        # Auto-sigma is great for image-based convolutions, but makes less sense in this
        # context, since sigma directly controls the error of the Wasserstein approximations.
        #
        # Expect this default sigma value to be blurry.
        sigmas_per_pixel = 2.5
        sigma = 0.5 * (window_size - 1) / sigmas_per_pixel
    window = _gauss(window_size, sigma).unsqueeze(1)
    # A 2D window could be constructed here.
    return window, sigma
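# Example: with the auto-sigma fallback, a 5-tap window yields
# sigma = 0.5 * (5 - 1) / 2.5 = 0.8, noticeably wider (blurrier) than the
# sigma=0.2 and sigma=0.1 values passed explicitly below.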
def wasserstein_distance(mu0, mu1):
    """
    mu0, mu1: The source and target distributions
    Returns: Convolutional Wasserstein distance between the two distributions.
    """
    # Convolutional window size.
    k = 5
    # Smaller sigma = more accurate distance, but requires more iterations.
    Hw, sigma = _create_window(k, sigma=0.2)
    Hw = Hw.reshape([1, 1, Hw.shape[0]])
    H = lambda x: F.conv1d(x, Hw, padding=k // 2)[0]
    gamma = sigma * sigma
    # Uniform area weights over the n bins.
    a = 1.0 / mu0.shape[-1]
    v = torch.ones(mu0.shape)
    w = torch.ones(mu0.shape)
    # Sinkhorn-style fixed-point iterations on the scaling vectors.
    for _ in range(40):
        v = mu0 / H(a * w)
        w = mu1 / H(a * v)
    return gamma * (a * (mu0 * v.log() + mu1 * w.log())).sum()
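# A quick sanity sketch: at the Sinkhorn fixed point, swapping mu0 and mu1
# simply swaps the roles of v and w, so the distance is symmetric up to
# iteration error:
#
#   p = torch.tensor([[[.50, .25, .01, .01, .01, .01, .20, .01]]])
#   q = torch.tensor([[[.20, .01, .01, .01, .01, .74, .01, .01]]])
#   wasserstein_distance(p, q)  # ~= wasserstein_distance(q, p)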
def wasserstein_barycenter(mu_s, weights):
    """
    mu_s: The distribution endpoints
    weights: The weights for each distribution
    Returns: a new distribution, an interpolation of the endpoints according to the given weights.
    """
    assert len(mu_s) == len(weights)
    # Convolutional window size.
    k = 5
    # Smaller sigma = sharper interpolation, but requires more iterations.
    Hw, sigma = _create_window(k, sigma=0.1)
    Hw = Hw.reshape([1, 1, Hw.shape[0]])
    H = lambda x: F.conv1d(x.unsqueeze(0), Hw, padding=k // 2)[0]
    gamma = sigma * sigma
    shape = mu_s.shape
    # Uniform area weights over the n bins.
    a = 1.0 / shape[-1]
    v = torch.ones(mu_s.shape)
    w = torch.ones(mu_s.shape)
    d = torch.ones(mu_s.shape)
    for _ in range(5):
        mu = torch.ones(1, 1, shape[-1])
        # Constraint 1: pi_i marginalizes to mu_i in one direction.
        for i in range(len(weights)):
            w[i] = mu_s[i] / H(a * v[i])
            d[i] = v[i] * H(a * w[i])
            # Accumulate the weighted geometric mean of the marginals.
            mu = mu * torch.pow(d[i], weights[i])
        # NOTE: Entropic sharpening not implemented.
        # Constraint 2: all pi_i marginalize to the same mu in the other direction.
        for i in range(len(weights)):
            v[i] = v[i] * mu / d[i]
    return mu
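# Usage sketch: with equal weights the result is the Wasserstein midpoint of
# the endpoints. Unlike the linear blend in the demo below, it translates
# mass between bins rather than cross-fading bar heights in place, e.g. with
# the endpoints tensor defined below:
#
#   mid = wasserstein_barycenter(endpoints, torch.tensor([0.5, 0.5]))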
def plot(wd, lin, init, delay):
    """Draw the Wasserstein and linear interpolations side by side as bar charts."""
    import matplotlib.pyplot as plt
    if init:
        plt.ion()
        plt.show()
    plt.clf()
    plt.title("Wasserstein vs. Linear Interpolation")
    colors = {'Wasserstein': 'royalblue', 'Linear': 'darkorange'}
    labels = list(colors.keys())
    handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels]
    plt.legend(handles, labels)
    plt.ylim(0, 1)
    # Interleave the two 8-bar series on odd/even x positions.
    plt.bar(range(1, 17, 2), height=wd.squeeze(), color=colors["Wasserstein"])
    plt.bar(range(2, 17, 2), height=lin.squeeze(), color=colors["Linear"])
    plt.pause(delay)
# Compute the Wasserstein distance between two distributions.
wd = wasserstein_distance(torch.tensor([[[.50, .25, .01, .01, .01, .01, .20, .01]]]),
                          torch.tensor([[[.20, .01, .01, .01, .01, .74, .01, .01]]]))
print("Wasserstein Distance:", wd)

# Interpolate between two distributions; the second argument to
# wasserstein_barycenter() holds the barycentric weights.
endpoints = torch.tensor([
    [[.50, .25, .01, .01, .01, .01, .20, .01]],
    [[.20, .01, .01, .01, .01, .74, .01, .01]]
])

# Wasserstein interpolation, animated forever (Ctrl+C to quit).
steps = 50
while True:
    init = True
    for i in range(steps):
        v = i / (steps - 1)
        weights = torch.tensor([1 - v, v])
        mu = wasserstein_barycenter(endpoints, weights)
        linear = endpoints[0] * weights[0] + endpoints[1] * weights[1]
        plot(mu, linear, init, (1.0 / steps) * 2)
        init = False
        print(mu)