Created February 9, 2019 16:57
-
-
Save flxai/ec2894f64a73ba28997b35643a464e80 to your computer and use it in GitHub Desktop.
compare-emd-loss-wasserstein-distance
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.stats import wasserstein_distance | |
from emd import EMDLoss | |
def gaussian(x, mu, sig):
    """Evaluate an unnormalised Gaussian bump exp(-(x - mu)^2 / (2 * sig^2)).

    Works element-wise, so ``x`` may be a scalar or a NumPy array. The peak
    value is 1 at ``x == mu`` (no 1 / (sig * sqrt(2 * pi)) normalisation).
    """
    squared_offset = np.power(x - mu, 2.)
    twice_variance = 2 * np.power(sig, 2.)
    return np.exp(-squared_offset / twice_variance)
# Three test signals, each 30 samples long:
#   a: Gaussian bump (mu=5, sigma=1) on [0, 10], circularly shifted 2 bins left,
#   b: a circularly shifted 4 bins right (same values, different positions),
#   c: a scaled by 1.5 (same positions, larger values).
a = np.roll(gaussian(np.linspace(0, 10, 30), 5, 1), -2)
b = np.roll(a, 4)
c = 1.5 * a

# Plot in the order a, b, c so colours and legend order match the labels.
for series, series_label in ((a, 'a'), (b, 'b'), (c, 'c')):
    plt.plot(series, label=series_label)
plt.legend()
plt.show()
dist = EMDLoss()

print("\nOriginal shapes")
for signal in (a, b, c):
    print(signal.shape)

# scipy.stats.wasserstein_distance(u_values, v_values) treats its inputs as
# *samples* of a 1-D distribution, not as histogram weights over bin positions.
# b is a circular shift (np.roll) of a, so both hold the same multiset of
# values and their distance is exactly 0.0; c = 1.5 * a scales every value.
# To compare the signals as histograms over bins instead, pass the bin
# positions as the values and the signals as u_weights / v_weights.
print("\nSciPy wasserstein_distance")
for pair_label, u, v in (('a b', a, b), ('a c', a, c), ('b c', b, c)):
    print(pair_label)
    print(wasserstein_distance(u, v))


def _as_point_tensor(values):
    """Wrap a 1-D ndarray as a (1, N, 1) float64 CUDA tensor that requires grad.

    The (1, N, 1) layout presumably means a batch of one holding N
    one-dimensional points, as expected by ``emd.EMDLoss`` -- TODO confirm
    against that module. ``.cuda()`` assumes a CUDA device is available.
    """
    tensor = torch.from_numpy(values[None, :, None]).cuda().double()
    tensor.requires_grad = True
    return tensor


ta = _as_point_tensor(a)
tb = _as_point_tensor(b)
tc = _as_point_tensor(c)

cost1 = dist(ta, tb)
cost2 = dist(ta, tc)
cost3 = dist(tb, tc)

print("\nNew shapes")
for tensor in (ta, tb, tc):
    print(tensor.shape)

print("\nPyTorch EMDLoss")
for pair_label, cost in (('a b', cost1), ('a c', cost2), ('b c', cost3)):
    print(pair_label)
    # detach() drops the autograd graph so the value can move to NumPy.
    print(cost.detach().cpu().numpy())

# Backpropagate through one cost; presumably a smoke test that gradients
# flow through EMDLoss back to the input tensors.
loss1 = torch.sum(cost1)
print(loss1)
loss1.backward()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original shapes | |
(30,) | |
(30,) | |
(30,) | |
PyTorch EMDLoss | |
a b | |
0.0 | |
a c | |
0.12115367548619291 | |
b c | |
0.12115367548619291 | |
New shapes | |
torch.Size([1, 30, 1]) | |
torch.Size([1, 30, 1]) | |
torch.Size([1, 30, 1]) | |
PyTorch EMDLoss | |
a b | |
[0.04349525] | |
a c | |
[0.46960036] | |
b c | |
[0.89565078] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment