Created November 14, 2018 01:43
2D Histogram Wasserstein Distance via POT Library
""" | |
Programmer: Chris Tralie | |
Purpose: To use the POT library (https://github.com/rflamary/POT) | |
to compute the Entropic regularized Wasserstein distance | |
between points on a 2D grid | |
""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import ot | |
def testMovingDisc(): | |
""" | |
Show optimal transport on a moving disc in a 50x50 grid | |
""" | |
## Step 1: Setup problem | |
pix = np.linspace(-1, 1, 80) | |
# Setup grid | |
X, Y = np.meshgrid(pix, pix) | |
# Compute pariwise distances between points on 2D grid so we know | |
# how to score the Wasserstein distance | |
coords = np.array([X.flatten(), Y.flatten()]).T | |
coordsSqr = np.sum(coords**2, 1) | |
M = coordsSqr[:, None] + coordsSqr[None, :] - 2*coords.dot(coords.T) | |
M[M < 0] = 0 | |
M = np.sqrt(M) | |
ts = np.linspace(-0.8, 0.8, 100) | |
## Step 2: Compute L2 distances and Wasserstein | |
Images = [] | |
radius = 0.2 | |
L2Dists = [0.0] | |
WassDists = [0.0] | |
for i, t in enumerate(ts): | |
I = 1e-5 + np.array((X-t)**2 + (Y-t)**2 < radius**2, dtype=float) | |
I /= np.sum(I) | |
Images.append(I) | |
if i > 0: | |
L2Dists.append(np.sqrt(np.sum((I-Images[0])**2))) | |
wass = ot.sinkhorn2(Images[0].flatten(), I.flatten(), M, 1.0) | |
print(wass) | |
WassDists.append(wass) | |
## Step 3: Make Animation | |
L2Dist = np.array(L2Dists) | |
WassDists = np.array(WassDists) | |
I0 = Images[0] | |
plt.figure(figsize=(15, 5)) | |
displacements = np.sqrt(2)*(ts - ts[0]) | |
for i, I in enumerate(Images): | |
plt.clf() | |
D = np.concatenate((I0[:, :, None], I[:, :, None], 0*I[:, :, None]), 2) | |
D = D*255/np.max(I0) | |
D = np.array(D, dtype=np.uint8) | |
plt.subplot(131) | |
plt.imshow(D, extent = (pix[0], pix[-1], pix[-1], pix[0])) | |
plt.subplot(132) | |
plt.plot(displacements, L2Dists) | |
plt.stem([displacements[i]], [L2Dists[i]]) | |
plt.xlabel("Displacements") | |
plt.ylabel("L2 Dist") | |
plt.title("L2 Dist") | |
plt.subplot(133) | |
plt.plot(displacements, WassDists) | |
plt.stem([displacements[i]], [WassDists[i]]) | |
plt.xlabel("Displacements") | |
plt.ylabel("Wasserstein Dist") | |
plt.title("Wasserstein Dist") | |
plt.savefig("%i.png"%i, bbox_inches='tight') | |
if __name__ == '__main__': | |
testMovingDisc() |
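The script's point is that the L2 distance stops growing once the moved disc no longer overlaps the original, while the Wasserstein distance keeps growing with displacement. A 1D analogue makes this easy to check without POT: on a uniform 1D grid with ground cost |x_i - x_j|, the exact 1-Wasserstein distance equals the area between the two cumulative distributions. This is a minimal sketch, assuming box-shaped histograms, no 1e-5 floor, and the hypothetical helpers `box` and `w1_1d` (they are not part of the gist).

```python
import numpy as np


def w1_1d(a, b, dx):
    """Exact 1-Wasserstein distance between two histograms on the same
    uniform 1D grid (spacing dx, ground cost |x_i - x_j|): the area
    between their cumulative distribution functions."""
    a = a / a.sum()
    b = b / b.sum()
    return float(np.sum(np.abs(np.cumsum(a) - np.cumsum(b))) * dx)


def box(n, start, width):
    """Hypothetical helper: histogram on n cells with unit mass on
    cells [start, start + width)."""
    h = np.zeros(n)
    h[start:start + width] = 1.0
    return h


n, dx, width = 200, 0.01, 20
a = box(n, 10, width)
results = []  # (displacement, L2 distance, W1 distance)
for shift in [0, 10, 20, 40, 80]:
    b = box(n, 10 + shift, width)
    l2 = float(np.sqrt(np.sum((a / a.sum() - b / b.sum()) ** 2)))
    results.append((shift * dx, l2, w1_1d(a, b, dx)))

for disp, l2, w1 in results:
    print(f"displacement={disp:.2f}  L2={l2:.4f}  W1={w1:.4f}")
```

Once the shift reaches the box width the boxes no longer overlap, so the L2 column stays fixed while W1 keeps tracking the displacement. This should mirror the flat L2 curve and the rising Wasserstein curve in the script's animation.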
ctralie (author) commented on Jul 29, 2019, via email:
That is great to know, thank you so much for the update!
On Mon, Jul 29, 2019, 9:16 AM, raphiol <***@***.***> wrote:
I finally have the answer to my questions, so I thought I would pass it on to you too:
The function ot.sinkhorn2 actually computes an approximation of the optimal transport problem (an entropy-regularized version of it). This is why it causes some problems.
The exact value of the Wasserstein distance is obtained by using the ot.emd2 function instead.
It takes a bit longer and the number of iterations must be increased, but it works!
Regards
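raphiol's point can be illustrated without POT. This is a minimal sketch, assuming a plain NumPy Sinkhorn-Knopp loop (the hypothetical helper `sinkhorn_cost`) that returns the transport cost <P, M> of the entropy-regularized plan P. That is intended to mirror the quantity ot.sinkhorn2 reports. The sketch compares this cost against the exact 1D distance, which is what ot.emd2 would return on this problem. Because the regularized plan is still a valid coupling, its cost cannot fall below the exact optimum, and the gap shrinks as the regularization strength decreases. In the gist itself, the corresponding swap would be ot.emd2(a, b, M), whose numItermax argument sets the iteration limit.

```python
import numpy as np


def sinkhorn_cost(a, b, M, reg, n_iter=5000, tol=1e-12):
    """Hypothetical helper: Sinkhorn-Knopp scaling for entropy-regularized
    optimal transport. Returns <P, M>, the transport cost of the
    regularized plan P (the entropy term itself is not included)."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)   # rescale so column sums match b
        u = a / (K @ v)     # rescale so row sums match a (exact after this)
        P = u[:, None] * K * v[None, :]
        if np.max(np.abs(P.sum(axis=0) - b)) < tol:
            break
    return float(np.sum(P * M))


def exact_w1_1d(a, b, dx):
    """Exact OT cost on a uniform 1D grid with cost |x_i - x_j|:
    the area between the two cumulative distributions."""
    return float(np.sum(np.abs(np.cumsum(a) - np.cumsum(b))) * dx)


# Two overlapping box histograms on a 40-cell grid with spacing dx
n, dx = 40, 0.025
x = np.arange(n) * dx
M = np.abs(x[:, None] - x[None, :])
a = np.zeros(n)
a[0:20] = 1.0
b = np.zeros(n)
b[10:30] = 1.0
a /= a.sum()
b /= b.sum()

exact = exact_w1_1d(a, b, dx)
approx = {reg: sinkhorn_cost(a, b, M, reg) for reg in [1.0, 0.1, 0.01]}
print("exact:", exact)
for reg, cost in approx.items():
    print(f"reg={reg}: sinkhorn cost={cost:.5f}, gap={cost - exact:.5f}")
```

With reg=1.0 (the value the script passes to ot.sinkhorn2) the regularized cost sits noticeably above the exact value. Smaller regularization moves it closer, at the price of more iterations and harsher numerics in exp(-M / reg).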
Justin commented on Mar 23, 2020:
This gist is perfect for describing earth mover's distance. Thank you!

ctralie (author) replied via email:
Yaaay, so glad it helped! Usually when I get these notifications there's a bug, so I'm happy to see a fun one for a change.