Created
          February 9, 2019 16:57 
        
      - 
      
- 
        Save flxai/ec2894f64a73ba28997b35643a464e80 to your computer and use it in GitHub Desktop. 
    compare-emd-loss-wasserstein-distance
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from scipy.stats import wasserstein_distance | |
| from emd import EMDLoss | |
def gaussian(x, mu, sig):
    """Evaluate the unnormalized Gaussian bump exp(-(x - mu)**2 / (2 * sig**2)).

    Operates element-wise on NumPy arrays (or on scalars). No
    1 / (sig * sqrt(2*pi)) normalization is applied, so the value at
    ``x == mu`` is exactly 1.
    """
    offset_sq = np.power(x - mu, 2.)
    two_var = 2 * np.power(sig, 2.)
    return np.exp(-offset_sq / two_var)
# Reference curve: a Gaussian bump (mean 5, sigma 1) sampled at 30 evenly
# spaced points on [0, 10], then cyclically shifted 2 samples to the left.
grid = np.linspace(0, 10, 30)
a = np.roll(gaussian(grid, 5, 1), -2)
# b: a cyclic shift of a by 4 samples, so it holds the same values as a,
# just in a different order.
b = np.roll(a, 4)
# c: a with its amplitude scaled by 1.5.
c = a * 1.5

# Visual check of the three curves before comparing them.
for name, curve in (('a', a), ('b', b), ('c', c)):
    plt.plot(curve, label=name)
plt.legend()
plt.show()
# EMD loss module under test; it is used in the PyTorch comparison below.
dist = EMDLoss()

print("\nOriginal shapes")
for arr in (a, b, c):
    print(arr.shape)

# The header names the library this section actually uses (SciPy, not EMDLoss).
print("\nSciPy wasserstein_distance")
# scipy.stats.wasserstein_distance(u, v) treats u and v as *samples* (empirical
# distributions of values), not as histograms over the index axis. b is a
# cyclic shift of a, so both contain the same values and the a-b distance is 0.
# NOTE(review): the PyTorch section feeds EMDLoss tensors of shape (1, 30, 1),
# presumably a batch of 30 one-dimensional points, which matches these sample
# semantics -- confirm against the emd module. If the curves are instead meant
# as histograms over the grid, compare them with explicit positions and weights:
#   pos = np.arange(len(a))
#   wasserstein_distance(pos, pos, u_weights=a, v_weights=b)
for label, (u, v) in (('a b', (a, b)), ('a c', (a, c)), ('b c', (b, c))):
    print(label)
    print(wasserstein_distance(u, v))
# Lift each 1-D curve to shape (1, len, 1), move it to the GPU as float64 and
# make it a leaf tensor that tracks gradients.
# NOTE(review): the (1, N, 1) layout presumably means (batch, points, point_dim)
# for EMDLoss -- confirm against the emd module.
ta, tb, tc = (
    torch.from_numpy(arr[None, :, None]).cuda().double().requires_grad_()
    for arr in (a, b, c)
)

cost1, cost2, cost3 = dist(ta, tb), dist(ta, tc), dist(tb, tc)

print("\nNew shapes")
for tensor in (ta, tb, tc):
    print(tensor.shape)

print("\nPyTorch EMDLoss")
for label, cost in (('a b', cost1), ('a c', cost2), ('b c', cost3)):
    print(label)
    print(cost.detach().cpu().numpy())

# Reduce the a-b cost (shape (1,) in the recorded output) to a scalar and run
# backward, presumably to exercise EMDLoss's gradient path. The gradients that
# land on ta/tb are not inspected here.
loss1 = cost1.sum()
print(loss1)
loss1.backward()
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | Original shapes | |
| (30,) | |
| (30,) | |
| (30,) | |
| PyTorch EMDLoss | |
| a b | |
| 0.0 | |
| a c | |
| 0.12115367548619291 | |
| b c | |
| 0.12115367548619291 | |
| New shapes | |
| torch.Size([1, 30, 1]) | |
| torch.Size([1, 30, 1]) | |
| torch.Size([1, 30, 1]) | |
| PyTorch EMDLoss | |
| a b | |
| [0.04349525] | |
| a c | |
| [0.46960036] | |
| b c | |
| [0.89565078] | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment