Skip to content

Instantly share code, notes, and snippets.

@ctralie
Created November 14, 2018 01:43
Show Gist options
  • Save ctralie/66352ae6ab06c009f02c705385a446f3 to your computer and use it in GitHub Desktop.
Save ctralie/66352ae6ab06c009f02c705385a446f3 to your computer and use it in GitHub Desktop.
2D Histogram Wasserstein Distance via POT Library
"""
Programmer: Chris Tralie
Purpose: To use the POT library (https://github.com/rflamary/POT)
to compute the entropy-regularized Wasserstein distance
between 2D histograms defined on a regular grid
"""
import numpy as np
import matplotlib.pyplot as plt
import ot
def _pairwiseDistances(coords):
    """
    Return the matrix of Euclidean distances between every pair of rows
    of coords, so entry (i, j) is the distance from point i to point j
    """
    sqrNorms = np.sum(coords**2, 1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    D = sqrNorms[:, None] + sqrNorms[None, :] - 2*coords.dot(coords.T)
    # Floating point roundoff can leave tiny negative entries;
    # clamp them so the square root stays real
    D[D < 0] = 0
    return np.sqrt(D)


def _discHistogram(X, Y, t, radius):
    """
    Return an image of a disc of the given radius centered at (t, t)
    on the grid (X, Y), normalized so that its entries sum to 1
    """
    # The 1e-5 floor keeps every bin strictly positive
    # (presumably so the Sinkhorn iterations never see an empty bin)
    I = 1e-5 + np.array((X-t)**2 + (Y-t)**2 < radius**2, dtype=float)
    return I/np.sum(I)


def testMovingDisc(gridRes=80, radius=0.2, numFrames=100, reg=1.0):
    """
    Show optimal transport on a disc moving along the diagonal of a
    gridRes x gridRes grid covering [-1, 1] x [-1, 1].  The disc center
    moves from (-0.8, -0.8) to (0.8, 0.8), and every frame is compared
    to the first frame with both the L2 distance and the entropic
    regularized Wasserstein distance.  One plot per frame is saved as
    "<frame index>.png" in the current working directory, and each
    Wasserstein distance is printed as it is computed.

    Parameters
    ----------
    gridRes: int
        Number of grid points along each axis (at least 2).  The ground
        cost matrix has gridRes**2 rows and columns, so memory grows
        with the fourth power of this value
    radius: float
        Radius of the disc in grid coordinates
    numFrames: int
        Number of disc positions / saved frames (at least 1)
    reg: float
        Entropic regularization weight passed to ot.sinkhorn2

    Raises
    ------
    ValueError
        If gridRes < 2 or numFrames < 1
    """
    if gridRes < 2:
        raise ValueError("gridRes must be at least 2")
    if numFrames < 1:
        raise ValueError("numFrames must be at least 1")

    ## Step 1: Setup problem
    pix = np.linspace(-1, 1, gridRes)
    # Setup grid
    X, Y = np.meshgrid(pix, pix)
    # Compute pairwise distances between points on 2D grid so we know
    # how to score the Wasserstein distance
    coords = np.array([X.flatten(), Y.flatten()]).T
    M = _pairwiseDistances(coords)
    ts = np.linspace(-0.8, 0.8, numFrames)

    ## Step 2: Compute L2 distances and Wasserstein
    Images = [_discHistogram(X, Y, t, radius) for t in ts]
    I0 = Images[0]
    I0Flat = I0.flatten()
    # The first frame is the reference, so both of its distances are 0
    L2Dists = [0.0]
    WassDists = [0.0]
    for I in Images[1:]:
        L2Dists.append(np.sqrt(np.sum((I-I0)**2)))
        # Depending on the POT version, sinkhorn2 hands back either a
        # scalar or a length-1 array; always store a plain float so the
        # list converts to a regular 1D array below
        wass = float(np.squeeze(ot.sinkhorn2(I0Flat, I.flatten(), M, reg)))
        print(wass)
        WassDists.append(wass)
    L2Dists = np.array(L2Dists)
    WassDists = np.array(WassDists)

    ## Step 3: Make Animation
    fig = plt.figure(figsize=(15, 5))
    # The disc moves by the same amount in x and y, hence the sqrt(2)
    displacements = np.sqrt(2)*(ts - ts[0])
    maxI0 = np.max(I0)
    for i, I in enumerate(Images):
        plt.clf()
        # Red channel: first frame.  Green channel: current frame
        D = np.concatenate((I0[:, :, None], I[:, :, None], 0*I[:, :, None]), 2)
        D = D*255/maxI0
        # A frame whose disc covers fewer pixels than the first frame has
        # a larger peak and would exceed 255; clip so the uint8 cast
        # cannot wrap around
        D = np.array(np.clip(D, 0, 255), dtype=np.uint8)
        plt.subplot(131)
        plt.imshow(D, extent = (pix[0], pix[-1], pix[-1], pix[0]))
        plt.subplot(132)
        plt.plot(displacements, L2Dists)
        plt.stem([displacements[i]], [L2Dists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("L2 Dist")
        plt.title("L2 Dist")
        plt.subplot(133)
        plt.plot(displacements, WassDists)
        plt.stem([displacements[i]], [WassDists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("Wasserstein Dist")
        plt.title("Wasserstein Dist")
        plt.savefig("%i.png"%i, bbox_inches='tight')
    # Release the figure now that every frame has been written
    plt.close(fig)
# Render the moving disc frames only when this file is run as a script,
# not when it is imported as a module
if __name__ == "__main__":
    testMovingDisc()
@raphiol
Copy link

raphiol commented Jul 29, 2019

I finally have the answer to my questions, so I thought I would give them to you too:

The function ot.sinkhorn2 only computes an approximation of the solution to the optimal transport problem. This is why it causes some problems.
The exact value for the Wasserstein distance is obtained by using the ot.emd2 function instead.
It takes a bit longer and the number of iterations must be increased, but it works!

Regards

@ctralie
Copy link
Author

ctralie commented Jul 29, 2019 via email

@justinblaber
Copy link

This gist is perfect for describing earth movers distance. Thank you!

@ctralie
Copy link
Author

ctralie commented Mar 23, 2020 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment