@ctralie
Created November 14, 2018 01:43
2D Histogram Wasserstein Distance via POT Library
"""
Programmer: Chris Tralie
Purpose: To use the POT library (https://github.com/rflamary/POT)
to compute the Entropic regularized Wasserstein distance
between points on a 2D grid
"""
import numpy as np
import matplotlib.pyplot as plt
import ot

def testMovingDisc():
    """
    Show optimal transport on a disc moving across an 80x80 grid
    """
    ## Step 1: Setup problem
    pix = np.linspace(-1, 1, 80)
    # Setup grid
    X, Y = np.meshgrid(pix, pix)
    # Compute pairwise distances between points on the 2D grid; this is
    # the ground cost matrix M used to score the Wasserstein distance
    coords = np.array([X.flatten(), Y.flatten()]).T
    coordsSqr = np.sum(coords**2, 1)
    M = coordsSqr[:, None] + coordsSqr[None, :] - 2*coords.dot(coords.T)
    M[M < 0] = 0  # clamp tiny negatives caused by floating-point cancellation
    M = np.sqrt(M)
    ts = np.linspace(-0.8, 0.8, 100)

    ## Step 2: Compute L2 distances and Wasserstein
    Images = []
    radius = 0.2
    L2Dists = [0.0]
    WassDists = [0.0]
    for i, t in enumerate(ts):
        # Disc centered at (t, t); the 1e-5 floor avoids empty (zero) bins
        I = 1e-5 + np.array((X-t)**2 + (Y-t)**2 < radius**2, dtype=float)
        I /= np.sum(I)  # normalize to a probability histogram
        Images.append(I)
        if i > 0:
            L2Dists.append(np.sqrt(np.sum((I-Images[0])**2)))
            # Entropic regularized OT cost, regularization strength 1.0
            wass = ot.sinkhorn2(Images[0].flatten(), I.flatten(), M, 1.0)
            print(wass)
            WassDists.append(wass)

    ## Step 3: Make Animation (one PNG frame per disc position)
    L2Dists = np.array(L2Dists)
    WassDists = np.array(WassDists)
    I0 = Images[0]
    plt.figure(figsize=(15, 5))
    # The disc moves along the diagonal, so its displacement is sqrt(2)*(t - t0)
    displacements = np.sqrt(2)*(ts - ts[0])
    for i, I in enumerate(Images):
        plt.clf()
        # Red channel: first disc; green channel: current disc
        D = np.concatenate((I0[:, :, None], I[:, :, None], 0*I[:, :, None]), 2)
        # Scale by the largest value in either frame so the uint8 cast cannot overflow
        D = D*255/np.max(D)
        D = np.array(D, dtype=np.uint8)
        plt.subplot(131)
        plt.imshow(D, extent = (pix[0], pix[-1], pix[-1], pix[0]))
        plt.subplot(132)
        plt.plot(displacements, L2Dists)
        plt.stem([displacements[i]], [L2Dists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("L2 Dist")
        plt.title("L2 Dist")
        plt.subplot(133)
        plt.plot(displacements, WassDists)
        plt.stem([displacements[i]], [WassDists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("Wasserstein Dist")
        plt.title("Wasserstein Dist")
        plt.savefig("%i.png"%i, bbox_inches='tight')

if __name__ == '__main__':
    testMovingDisc()
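Step 1 builds the ground-cost matrix with the expansion ‖x − y‖² = ‖x‖² + ‖y‖² − 2x·y, which is fast but can leave tiny negative values from floating-point cancellation; that is why the script clamps before taking the square root. The sketch below (the helper name `pairwise_dists` is ours, not part of the gist or of POT) checks that expansion against a brute-force norm on a small 10×10 grid. On the gist's 80×80 grid, M has 6400² ≈ 4.1×10⁷ entries, roughly 330 MB in float64, and the brute-force intermediate would be twice that.

```python
import numpy as np

def pairwise_dists(coords):
    """Euclidean distance matrix via |x|^2 + |y|^2 - 2 x.y, as in Step 1."""
    sq = np.sum(coords**2, 1)
    D2 = sq[:, None] + sq[None, :] - 2 * coords @ coords.T
    D2[D2 < 0] = 0  # cancellation can leave tiny negatives before the sqrt
    return np.sqrt(D2)

# Small grid so the brute-force comparison stays cheap
pix = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(pix, pix)
coords = np.array([X.flatten(), Y.flatten()]).T

M = pairwise_dists(coords)
# Brute force: explicit difference vectors, then their norms
brute = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=2)
print(np.allclose(M, brute, atol=1e-6), M.shape)
```

The tolerance is loose on purpose: on the diagonal the expansion can produce values around 1e-16 whose square roots are around 1e-8 rather than exactly 0.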
@ctralie
Author

ctralie commented Jul 18, 2019

Great questions. Unfortunately, I have not used this library beyond this example (which I wrote to help a student), so I might not be the best person to ask, but I will try:

Yes, the nonlinearity of the Wasserstein curve is strange, but I wonder if it's just due to the shape and discretization of the discs. It would be worth trying this with two squares that move past each other only in the x direction, for example.
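The suggested square experiment has a known exact answer: for a pure translation by a vector v, the exact Wasserstein distance equals the length of v, so the exact curve must be a straight line in the displacement. A minimal sketch, assuming we reduce squares moving only in x to their 1D profiles along x (the helpers `w1_1d` and `box` are hypothetical names), uses the closed form for 1D W1 on a uniform grid: the sum of absolute CDF differences times the grid spacing.

```python
import numpy as np

def w1_1d(a, b, dx):
    """Exact Wasserstein-1 between two histograms on the same uniform 1D grid."""
    return float(np.sum(np.abs(np.cumsum(a) - np.cumsum(b))) * dx)

n = 80
x = np.linspace(-1, 1, n)
dx = x[1] - x[0]
width = 4  # box width in pixels, like the 4-pixel squares tried in this thread

def box(start):
    """Normalized histogram that is uniform on `width` pixels starting at `start`."""
    h = np.zeros(n)
    h[start:start + width] = 1.0
    return h / h.sum()

a = box(10)
shifts = np.arange(0, 40)  # every shifted box stays inside the grid
dists = np.array([w1_1d(a, box(10 + s), dx) for s in shifts])
# For a pure translation the exact W1 equals the shift distance
print(np.allclose(dists, shifts * dx))
```

Since the exact distance is linear even for a 4-pixel box, a clearly curved plot suggests the solver, rather than the shape or its discretization, is responsible.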

It's strange that the distance from I to itself is not zero, because it should be.
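For the exact problem, the distance from a histogram to itself is 0, because the identity coupling costs nothing. The entropic problem that ot.sinkhorn2 solves instead spreads mass across every pair of bins, so its transport cost is positive even when both inputs are equal, and it shrinks as the regularization gets smaller. Below is a plain-NumPy sketch of Sinkhorn-Knopp iterations on a 1D box that stands in for ot.sinkhorn2; `sinkhorn_cost` is our own helper, and we assume POT reports the same ⟨P, M⟩ quantity.

```python
import numpy as np

def sinkhorn_cost(a, b, M, reg, n_iter=2000):
    """Return <P, M> for the entropic OT plan P (plain Sinkhorn-Knopp sketch)."""
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # rescale so the row sums match a
        v = b / (K.T @ u)  # rescale so the column sums match b
    P = u[:, None] * K * v[None, :]
    return float(np.sum(P * M))

# 1D grid and a box histogram with the same 1e-5 floor as the gist
x = np.linspace(-1, 1, 50)
M = np.abs(x[:, None] - x[None, :])  # Euclidean ground cost in 1D
a = 1e-5 + (np.abs(x) < 0.2).astype(float)
a /= a.sum()

exact_self = 0.0  # the identity coupling transports nothing
cost_reg1 = sinkhorn_cost(a, a, M, reg=1.0)
cost_reg01 = sinkhorn_cost(a, a, M, reg=0.1)
print(exact_self, cost_reg1, cost_reg01)
```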

Yes, I was as confused as you about the problems that occur when there are zeros, and that is indeed why I added 1e-5.

@raphiol

raphiol commented Jul 19, 2019

Thank you for your answer. I tried with small squares of 4 pixels and the problem was already there...
I will continue to investigate!

@raphiol

raphiol commented Jul 29, 2019

I finally have the answers to my questions, so I thought I would share them with you too:

The function ot.sinkhorn2 only approximates the optimal transport problem (it solves an entropically regularized version of it), which is why it causes these problems.
The exact value of the Wasserstein distance is obtained by using the ot.emd2 function instead.
It takes a bit longer, and the maximum number of iterations (the numItermax argument) must be increased, but it works!
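For reference, here is a statement of the two problems, with our own notation built on the gist's names: histograms a and b, ground cost M, and regularization strength ε (the 1.0 passed to ot.sinkhorn2). ot.emd2 solves the first one exactly with a network-simplex linear-program solver, whose iteration cap is numItermax. ot.sinkhorn2 iterates toward P_ε, and to our understanding it reports ⟨P_ε, M⟩. Every entry of P_ε is positive, so that cost is strictly positive even when a = b. It is nondecreasing in ε and tends to W(a, b) as ε → 0.

```latex
W(a,b) = \min_{P \in U(a,b)} \langle P, M \rangle,
\qquad
U(a,b) = \{\, P \ge 0 \;:\; P\mathbf{1} = a,\; P^{\top}\mathbf{1} = b \,\}

P_\varepsilon = \operatorname*{arg\,min}_{P \in U(a,b)}
  \langle P, M \rangle + \varepsilon \sum_{i,j} P_{ij} \log P_{ij}
```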

Regards

@ctralie
Author

ctralie commented Jul 29, 2019 via email

@justinblaber

This gist is perfect for describing the earth mover's distance. Thank you!

@ctralie
Author

ctralie commented Mar 23, 2020 via email
