@flxai
Created February 9, 2019 16:57
compare-emd-loss-wasserstein-distance

import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance
from emd import EMDLoss  # third-party PyTorch EMD (earth mover's distance) loss


def gaussian(x, mu, sig):
    return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))


# Three test curves on a common grid: b is a (circularly) shifted copy of a,
# c is a scaled copy of a
a = gaussian(np.linspace(0, 10, 30), 5, 1)
a = np.roll(a, -2)
b = np.roll(a, 4)
c = a * 1.5

plt.plot(a, label='a')
plt.plot(b, label='b')
plt.plot(c, label='c')
plt.legend()
plt.show()

dist = EMDLoss()

print("\nOriginal shapes")
print(a.shape)
print(b.shape)
print(c.shape)

# Baseline: SciPy's 1-D Wasserstein distance on the raw arrays
print("\nSciPy wasserstein_distance")
print('a b')
print(wasserstein_distance(a, b))
print('a c')
print(wasserstein_distance(a, c))
print('b c')
print(wasserstein_distance(b, c))
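
# Note that scipy.stats.wasserstein_distance treats its two arguments as
# samples of values, not as densities on a grid; b = np.roll(a, 4) contains
# the same values as a in a different order, so their distance prints as 0.0.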

# Reshape to (batch, n_points, point_dim) = (1, 30, 1); this EMDLoss appears
# to expect CUDA double tensors of point clouds in that layout
ta = torch.from_numpy(a[None, :, None]).cuda().double()
tb = torch.from_numpy(b[None, :, None]).cuda().double()
tc = torch.from_numpy(c[None, :, None]).cuda().double()
ta.requires_grad = True
tb.requires_grad = True
tc.requires_grad = True

cost1 = dist(ta, tb)
cost2 = dist(ta, tc)
cost3 = dist(tb, tc)
print("\nNew shapes")
print(ta.shape)
print(tb.shape)
print(tc.shape)
print("\nPyTorch EMDLoss")
print('a b')
print(cost1.detach().cpu().numpy())
print('a c')
print(cost2.detach().cpu().numpy())
print('b c')
print(cost3.detach().cpu().numpy())
loss1 = torch.sum(cost1)
print(loss1)
loss1.backward()
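
# After backward(), the gradients of loss1 with respect to the inputs of cost1
# should be populated on the leaf tensors, e.g. ta.grad and tb.grad (same shape
# as ta/tb), assuming EMDLoss provides a backward pass as the call above suggests:
#   print(ta.grad.shape)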

Output:

Original shapes
(30,)
(30,)
(30,)

SciPy wasserstein_distance
a b
0.0
a c
0.12115367548619291
b c
0.12115367548619291

New shapes
torch.Size([1, 30, 1])
torch.Size([1, 30, 1])
torch.Size([1, 30, 1])

PyTorch EMDLoss
a b
[0.04349525]
a c
[0.46960036]
b c
[0.89565078]
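
The SciPy numbers above compare the arrays as samples of values, which is why a and b (a permuted copy of a) come out at distance 0.0. If the intent is instead to compare the curves as distributions over the x-axis, SciPy can do that through its u_weights/v_weights arguments, and in 1-D the same quantity has a simple closed form: the L1 distance between the normalized CDFs, scaled by the grid spacing. The sketch below shows both for these curves; the w1_on_grid helper is introduced here purely for this check, and nothing is claimed about what the emd package's EMDLoss itself computes.

import numpy as np
from scipy.stats import wasserstein_distance


def gaussian(x, mu, sig):
    return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))


def w1_on_grid(p, q, x):
    """Wasserstein-1 distance between two (unnormalized) densities sampled on an
    evenly spaced grid x: L1 distance between normalized CDFs times the spacing."""
    p = p / p.sum()
    q = q / q.sum()
    return np.sum(np.abs(np.cumsum(p) - np.cumsum(q))) * (x[1] - x[0])


# Same curves as in the script above
x = np.linspace(0, 10, 30)
a = np.roll(gaussian(x, 5, 1), -2)
b = np.roll(a, 4)
c = a * 1.5

# Positions are the grid values, curve heights are the weights
print(wasserstein_distance(x, x, u_weights=a, v_weights=b))  # roughly a shift of 4 grid steps
print(wasserstein_distance(x, x, u_weights=a, v_weights=c))  # ~0: c is a scaled copy of a

# The closed form gives essentially the same numbers
print(w1_on_grid(a, b, x))
print(w1_on_grid(a, c, x))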