"""
Programmer: Chris Tralie
Purpose: To use the POT library (https://github.com/rflamary/POT)
to compute the entropic regularized Wasserstein distance
between points on a 2D grid
"""
import numpy as np
import matplotlib.pyplot as plt
import ot

def testMovingDisc():
    """
    Show optimal transport on a moving disc in an 80x80 grid
    """
    ## Step 1: Setup problem
    pix = np.linspace(-1, 1, 80)
    # Setup grid
    X, Y = np.meshgrid(pix, pix)
    # Compute pairwise distances between points on 2D grid so we know
    # how to score the Wasserstein distance
    coords = np.array([X.flatten(), Y.flatten()]).T
    coordsSqr = np.sum(coords**2, 1)
    M = coordsSqr[:, None] + coordsSqr[None, :] - 2*coords.dot(coords.T)
    M[M < 0] = 0  # Clamp tiny negative values caused by roundoff
    M = np.sqrt(M)
    ts = np.linspace(-0.8, 0.8, 100)

    ## Step 2: Compute L2 distances and Wasserstein
    Images = []
    radius = 0.2
    L2Dists = [0.0]
    WassDists = [0.0]
    for i, t in enumerate(ts):
        # Disc centered at (t, t); the 1e-5 offset avoids exact zeros
        I = 1e-5 + np.array((X-t)**2 + (Y-t)**2 < radius**2, dtype=float)
        I /= np.sum(I)
        Images.append(I)
        if i > 0:
            L2Dists.append(np.sqrt(np.sum((I-Images[0])**2)))
            wass = ot.sinkhorn2(Images[0].flatten(), I.flatten(), M, 1.0)
            # Some POT versions return a length-1 array; coerce to a float
            wass = float(np.squeeze(wass))
            print(wass)
            WassDists.append(wass)

    ## Step 3: Make Animation
    L2Dists = np.array(L2Dists)
    WassDists = np.array(WassDists)
    I0 = Images[0]
    plt.figure(figsize=(15, 5))
    displacements = np.sqrt(2)*(ts - ts[0])
    for i, I in enumerate(Images):
        plt.clf()
        # Reference disc in the red channel, moving disc in the green channel
        D = np.concatenate((I0[:, :, None], I[:, :, None], 0*I[:, :, None]), 2)
        D = D*255/np.max(I0)
        D = np.array(D, dtype=np.uint8)
        plt.subplot(131)
        plt.imshow(D, extent = (pix[0], pix[-1], pix[-1], pix[0]))
        plt.subplot(132)
        plt.plot(displacements, L2Dists)
        plt.stem([displacements[i]], [L2Dists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("L2 Dist")
        plt.title("L2 Dist")
        plt.subplot(133)
        plt.plot(displacements, WassDists)
        plt.stem([displacements[i]], [WassDists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("Wasserstein Dist")
        plt.title("Wasserstein Dist")
        plt.savefig("%i.png"%i, bbox_inches='tight')

if __name__ == '__main__':
    testMovingDisc()
I am trying to use the Wasserstein distance in my studies and I found your script during my research.
I have two questions about it; maybe you can help me better understand what's happening here.
- Do you know why the computed Wasserstein distance is nonlinear when the two discs are close to each other? Also, ot.sinkhorn2(I.flatten(), I.flatten(), M, 1.) is not equal to zero; is that normal for a distance?
- Do you know why there is a problem computing the Wasserstein distance when the I matrices contain zeros? I see you added 1e-5; I guess this is to avoid that...
Thank you
Great questions. Unfortunately, I have not used this library beyond this example (which was to help a student), so I might not be the best person to ask, but I will try.
Yes, the nonlinear part is strange, but I wonder if it's just because of the shape and discretization of the discs. It would be worth trying this with two squares that move past each other only in the x direction, for example.
It is strange that the distance from I to itself is not equal to zero, because it should be.
Yes, I was as confused as you about the problems when there are zeros, and that is indeed why I added 1e-5.
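One way to probe the nonzero self-distance is to run plain Sinkhorn iterations by hand on a tiny 1D problem and vary the regularization strength. The sketch below is NumPy only and does not call POT; the helper `sinkhorn_cost`, the 20-point grid, and the bump-shaped histogram are all hypothetical choices for illustration, and it assumes the reported value is the transport cost of the regularized plan.

```python
import numpy as np

def sinkhorn_cost(a, b, M, reg, n_iter=2000):
    """Transport cost sum(P * M) of the plan found by plain Sinkhorn iterations.

    Hypothetical helper for illustration; not the POT implementation.
    """
    K = np.exp(-M / reg)              # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)             # scale columns to match b
        u = a / (K @ v)               # scale rows to match a
    P = u[:, None] * K * v[None, :]   # regularized transport plan
    return float(np.sum(P * M))

# Small 1D grid with |x - y| as the ground cost
x = np.linspace(-1, 1, 20)
M = np.abs(x[:, None] - x[None, :])

# A smooth bump, normalized to sum to 1
a = np.exp(-x**2 / 0.1)
a /= a.sum()

# Cost of transporting a onto itself for several regularization strengths
self_costs = {reg: sinkhorn_cost(a, a, M, reg) for reg in (1.0, 0.1, 0.02)}
for reg, cost in self_costs.items():
    print("reg = %g: self cost = %g" % (reg, cost))
```

In this sketch the self cost stays positive and shrinks as reg decreases, which suggests that the regularization, rather than the shape or discretization of the discs, is what keeps the value away from zero.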
Thank you for your answer. I tried with small squares of 4 pixels and the problem was still there...
I will continue to investigate!
I finally have the answers to my questions, so I thought I would share them with you too:
The function ot.sinkhorn2 actually computes an approximation to the optimal transport problem (the entropic regularized version). This is why it causes these problems.
The exact value of the Wasserstein distance is obtained by using the ot.emd2 function instead.
It takes a bit longer and the number of iterations must be increased, but it works!
Regards
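For intuition about what the exact distance should look like, here is a minimal 1D analogue that needs neither POT nor an iterative solver. It assumes a cost of |x - y| on a uniform grid, where the exact Wasserstein-1 distance is the area between the two cumulative sums. The `box` and `wasserstein_1d` helpers and the grid sizes are hypothetical, and this is a 1D stand-in, not the 2D disc setup of the gist.

```python
import numpy as np

n = 201
x = np.linspace(-1, 1, n)
dx = x[1] - x[0]

def box(start, width=20):
    """Normalized box histogram covering grid indices [start, start + width)."""
    h = np.zeros(n)
    h[start:start + width] = 1.0 / width
    return h

def wasserstein_1d(a, b):
    """Exact W1 in 1D with cost |x - y|: area between the two CDFs."""
    return float(np.sum(np.abs(np.cumsum(a) - np.cumsum(b))) * dx)

ref = box(10)
steps = np.array([0, 40, 80, 120, 160])   # shifts in grid cells

w1 = np.array([wasserstein_1d(ref, box(10 + s)) for s in steps])
l2 = np.array([np.sqrt(np.sum((ref - box(10 + s))**2)) for s in steps])

for s, w, l in zip(steps, w1, l2):
    print("shift = %.2f  W1 = %.4f  L2 = %.4f" % (s * dx, w, l))
```

Here W1 equals the shift itself, so it grows linearly and is zero for a zero shift, while the L2 distance stops changing as soon as the two boxes no longer overlap.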
This gist is perfect for describing the earth mover's distance. Thank you!
Computing the Wasserstein distance between two sampled discs as one of them moves away from the other one. The L2 distance, by comparison, saturates very quickly (please excuse the aliasing in this simple example)