"""
Programmer: Chris Tralie
Purpose: To use the POT library (https://github.com/rflamary/POT)
to compute the entropic regularized Wasserstein distance
between points on a 2D grid
"""
import numpy as np
import matplotlib.pyplot as plt
import ot


def testMovingDisc():
    """
    Show optimal transport on a moving disc in an 80x80 grid
    """
    ## Step 1: Setup problem
    pix = np.linspace(-1, 1, 80)
    # Setup grid
    X, Y = np.meshgrid(pix, pix)
    # Compute pairwise distances between points on 2D grid so we know
    # how to score the Wasserstein distance
    coords = np.array([X.flatten(), Y.flatten()]).T
    coordsSqr = np.sum(coords**2, 1)
    M = coordsSqr[:, None] + coordsSqr[None, :] - 2*coords.dot(coords.T)
    # Clamp tiny negative values caused by roundoff before the square root
    M[M < 0] = 0
    M = np.sqrt(M)
    # Disc centers move along the diagonal from (-0.8, -0.8) to (0.8, 0.8)
    ts = np.linspace(-0.8, 0.8, 100)

    ## Step 2: Compute L2 distances and Wasserstein
    Images = []
    radius = 0.2
    L2Dists = [0.0]
    WassDists = [0.0]
    for i, t in enumerate(ts):
        # Small offset of 1e-5 so that no pixel has exactly zero mass
        I = 1e-5 + np.array((X-t)**2 + (Y-t)**2 < radius**2, dtype=float)
        I /= np.sum(I)
        Images.append(I)
        if i > 0:
            L2Dists.append(np.sqrt(np.sum((I-Images[0])**2)))
            wass = ot.sinkhorn2(Images[0].flatten(), I.flatten(), M, 1.0)
            print(wass)
            # Convert to a plain float in case a length-1 array is returned
            WassDists.append(float(np.squeeze(wass)))

    ## Step 3: Make Animation
    L2Dists = np.array(L2Dists)
    WassDists = np.array(WassDists)
    I0 = Images[0]
    plt.figure(figsize=(15, 5))
    # Distance traveled by the disc center along the diagonal
    displacements = np.sqrt(2)*(ts - ts[0])
    for i, I in enumerate(Images):
        plt.clf()
        # Red channel: first disc, green channel: current disc
        D = np.concatenate((I0[:, :, None], I[:, :, None], 0*I[:, :, None]), 2)
        D = D*255/np.max(I0)
        D = np.array(D, dtype=np.uint8)
        plt.subplot(131)
        plt.imshow(D, extent = (pix[0], pix[-1], pix[-1], pix[0]))
        plt.subplot(132)
        plt.plot(displacements, L2Dists)
        plt.stem([displacements[i]], [L2Dists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("L2 Dist")
        plt.title("L2 Dist")
        plt.subplot(133)
        plt.plot(displacements, WassDists)
        plt.stem([displacements[i]], [WassDists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("Wasserstein Dist")
        plt.title("Wasserstein Dist")
        plt.savefig("%i.png"%i, bbox_inches='tight')


if __name__ == '__main__':
    testMovingDisc()
Great questions. Unfortunately, I have not used this library beyond this example (which was to help a student), so I might not be the best person to ask, but I will try.
Yes, the nonlinear part is strange, but I wonder if it's just because of the shape and discretization of the discs. It would be worth trying this with two squares that move past each other only in the x direction, for example.
It's strange that the distance from I to itself is not equal to zero, because it should be.
Yes, I was as confused as you about the problems when there are zeros, and that is indeed why I added 1e-5.
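The two-squares experiment suggested above can be sketched with NumPy alone. This is only an illustration under stated assumptions: the helper name `square_image`, the 8-pixel square, and the pixel shifts are hypothetical choices, and the transport distance is not computed with POT. Instead it uses the closed-form 1D Wasserstein distance between the x-marginals (the sum of absolute CDF differences), which for a pure horizontal translation of one fixed shape coincides with the 2D distance under a Euclidean ground cost.

```python
import numpy as np

n = 80                          # grid resolution, as in the script
pix = np.linspace(-1, 1, n)
dx = pix[1] - pix[0]            # pixel spacing
w = 8                           # square side in pixels (hypothetical choice)

def square_image(col):
    """Normalized image of a w-by-w square whose left edge is at column col."""
    I = np.zeros((n, n))
    I[36:36 + w, col:col + w] = 1.0
    return I / I.sum()

I0 = square_image(4)
shifts = np.arange(0, 60, 4)    # horizontal shifts in whole pixels
l2 = []
w1 = []
for s in shifts:
    I = square_image(4 + s)
    # Pixelwise L2 distance, as in Step 2 of the script
    l2.append(np.sqrt(np.sum((I - I0)**2)))
    # Exact 1D Wasserstein distance between the x-marginals:
    # sum of |CDF difference| times the pixel spacing
    cdf_diff = np.cumsum(I.sum(axis=0) - I0.sum(axis=0))[:-1]
    w1.append(np.sum(np.abs(cdf_diff)) * dx)

l2 = np.array(l2)
w1 = np.array(w1)
for s, a, b in zip(shifts, l2, w1):
    print("shift %2d px: L2 = %.4f, W1 = %.4f" % (s, a, b))
```

With whole-pixel shifts the square keeps exactly the same discretized shape, so any remaining nonlinearity in a transport distance computed this way cannot come from the discretization. The L2 distance stops changing once the squares no longer overlap, while the transport distance keeps growing with the shift.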
Thank you for your answer. I tried with small squares of 4 pixels and the problem was still there...
I will continue to investigate!
I finally have the answer to my questions, so I thought I would give them to you too:
The function ot.sinkhorn2 actually computes an approximation to the optimal transport problem (the entropic regularized version), which is why it causes these problems.
The exact value of the Wasserstein distance is obtained by using the ot.emd2 function instead.
It takes a bit longer and the maximum number of iterations must be increased, but it works!
Regards
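To illustrate why an entropic regularized solver gives a non-zero distance from a distribution to itself, here is a minimal NumPy-only sketch on a small 1D grid. It is not POT's implementation: `sinkhorn_cost` and `exact_w1_1d` are hypothetical helper names, the loop is a plain textbook Sinkhorn iteration, and the exact value comes from the 1D CDF formula rather than from ot.emd2. The 1e-5 offset mirrors the one used in the script.

```python
import numpy as np

def sinkhorn_cost(a, b, M, reg, n_iter=2000):
    """Transport cost <P, M> of the plan found by plain Sinkhorn iterations."""
    K = np.exp(-M / reg)                 # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)                # match column marginals
        u = a / (K @ v)                  # match row marginals
    P = u[:, None] * K * v[None, :]      # regularized transport plan
    return float(np.sum(P * M))

def exact_w1_1d(a, b, x):
    """Exact 1-Wasserstein distance between histograms on sorted 1D points x."""
    cdf_diff = np.cumsum(a - b)[:-1]
    return float(np.sum(np.abs(cdf_diff) * np.diff(x)))

x = np.linspace(0, 1, 20)
M = np.abs(x[:, None] - x[None, :])      # ground cost: distance between grid points

# Two small bumps with a 1e-5 background, normalized to sum to 1
a = np.full(20, 1e-5); a[2:6] += 1.0; a /= a.sum()
b = np.full(20, 1e-5); b[4:8] += 1.0; b /= b.sum()

self_reg1 = sinkhorn_cost(a, a, M, 1.0)      # a to itself, strong regularization
self_reg001 = sinkhorn_cost(a, a, M, 0.01)   # a to itself, weak regularization
shift_reg1 = sinkhorn_cost(a, b, M, 1.0)     # a to b, strong regularization
exact_shift = exact_w1_1d(a, b, x)           # a to b, exact

print("a to a, reg=1.0 :", self_reg1)
print("a to a, reg=0.01:", self_reg001)
print("a to a, exact   :", exact_w1_1d(a, a, x))
print("a to b, reg=1.0 :", shift_reg1)
print("a to b, exact   :", exact_shift)
```

The regularization spreads mass over many pixel pairs, so the plan from a distribution to itself is not the identity and its cost is positive. Lowering the regularization parameter shrinks that cost toward the exact value of zero, at the price of slower and less stable iterations.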
This gist is perfect for describing the earth mover's distance. Thank you!
I am trying to use the Wasserstein distance in my studies and I have found your script in my research.
I have two questions about it, maybe you can help me to better understand what's happening here.
Thank you