Created November 14, 2018 01:43
2D Histogram Wasserstein Distance via POT Library
"""
Programmer: Chris Tralie
Purpose: To use the POT library (https://github.com/rflamary/POT)
to compute the Entropic regularized Wasserstein distance
between points on a 2D grid
"""
import numpy as np
import matplotlib.pyplot as plt
import ot

def testMovingDisc():
    """
    Show optimal transport on a moving disc in an 80x80 grid
    """
    ## Step 1: Setup problem
    pix = np.linspace(-1, 1, 80)
    # Setup grid
    X, Y = np.meshgrid(pix, pix)
    # Compute pairwise Euclidean distances between points on the 2D grid,
    # which serve as the ground cost for the Wasserstein distance
    coords = np.array([X.flatten(), Y.flatten()]).T
    coordsSqr = np.sum(coords**2, 1)
    # Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    M = coordsSqr[:, None] + coordsSqr[None, :] - 2*coords.dot(coords.T)
    M[M < 0] = 0  # Clamp small negatives caused by floating point roundoff
    M = np.sqrt(M)
    ts = np.linspace(-0.8, 0.8, 100)
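    ## Step 2: Compute L2 distances and Wasserstein
    Images = []
    radius = 0.2
    L2Dists = [0.0]
    WassDists = [0.0]
    for i, t in enumerate(ts):
        # Disc centered at (t, t); the small offset keeps every bin positive
        I = 1e-5 + np.array((X-t)**2 + (Y-t)**2 < radius**2, dtype=float)
        I /= np.sum(I)  # Normalize so the image sums to 1
        Images.append(I)
        if i > 0:
            L2Dists.append(np.sqrt(np.sum((I-Images[0])**2)))
            # Entropic regularized transport cost with regularization 1.0
            wass = ot.sinkhorn2(Images[0].flatten(), I.flatten(), M, 1.0)
            # Depending on the POT version, the result may be a length-1
            # array rather than a scalar, so convert it to a plain float
            wass = float(np.squeeze(wass))
            print(wass)
            WassDists.append(wass)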
    ## Step 3: Make Animation
    L2Dists = np.array(L2Dists)
    WassDists = np.array(WassDists)
    I0 = Images[0]
    plt.figure(figsize=(15, 5))
    # The disc center moves along the diagonal, so a change of dt in both
    # coordinates is a displacement of sqrt(2)*dt
    displacements = np.sqrt(2)*(ts - ts[0])
    for i, I in enumerate(Images):
        plt.clf()
        # Red channel: first frame, green channel: current frame, blue: empty
        D = np.concatenate((I0[:, :, None], I[:, :, None], 0*I[:, :, None]), 2)
        D = D*255/np.max(I0)
        D = np.array(D, dtype=np.uint8)
        plt.subplot(131)
        plt.imshow(D, extent = (pix[0], pix[-1], pix[-1], pix[0]))
        plt.subplot(132)
        plt.plot(displacements, L2Dists)
        plt.stem([displacements[i]], [L2Dists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("L2 Dist")
        plt.title("L2 Dist")
        plt.subplot(133)
        plt.plot(displacements, WassDists)
        plt.stem([displacements[i]], [WassDists[i]])
        plt.xlabel("Displacements")
        plt.ylabel("Wasserstein Dist")
        plt.title("Wasserstein Dist")
        plt.savefig("%i.png"%i, bbox_inches='tight')

if __name__ == '__main__':
    testMovingDisc()
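
The script above compares the L2 distance with the Wasserstein distance as a disc moves away from its starting position. The same comparison can be sketched in 1D with NumPy alone, which makes it easy to run without POT. This is a simplified illustration and not the script's method: it uses the exact 1D Wasserstein-1 distance computed from cumulative sums rather than the entropic regularized Sinkhorn cost, and the helper names (`box_histogram`, `l2_distance`, `wasserstein1_1d`), the grid size, and the box width are all assumptions chosen for the example.

```python
import numpy as np

def box_histogram(grid, center, half_width):
    """Normalized indicator of the interval around center (hypothetical helper)."""
    h = (np.abs(grid - center) < half_width).astype(float)
    return h / h.sum()

def l2_distance(a, b):
    """Euclidean distance between two histograms, bin by bin."""
    return np.sqrt(np.sum((a - b) ** 2))

def wasserstein1_1d(a, b, grid):
    """Exact 1D Wasserstein-1 distance on a uniform grid.

    In 1D, W1 equals the integral of |CDF_a - CDF_b|, which is
    approximated here by a sum over grid cells of width dx.
    """
    dx = grid[1] - grid[0]
    return np.sum(np.abs(np.cumsum(a) - np.cumsum(b))) * dx

# Uniform grid on [-1, 1]; the half width is chosen so that the box
# edges fall between grid points (an assumption made for robustness)
grid = np.linspace(-1, 1, 201)
start = -0.7
half_width = 0.105
reference = box_histogram(grid, start, half_width)

# Each entry is (shift, L2 distance, Wasserstein-1 distance)
results = []
for shift in [0.0, 0.1, 0.5, 1.0, 1.4]:
    moved = box_histogram(grid, start + shift, half_width)
    results.append((shift,
                    l2_distance(reference, moved),
                    wasserstein1_1d(reference, moved, grid)))

for shift, l2, w1 in results:
    print("shift=%.1f  L2=%.4f  W1=%.4f" % (shift, l2, w1))
```

Once the two boxes stop overlapping, the L2 distance stays constant no matter how far the box moves, while the Wasserstein distance keeps growing with the shift.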
Yaaay so glad it helped! Usually when I get these notifications there's a
bug, so I'm happy to see a fun one for a change
On Mon, Mar 23, 2020 at 5:40 PM Justin ***@***.***> wrote:
This gist is perfect for describing earth movers distance. Thank you!