Skip to content

Instantly share code, notes, and snippets.

@rahulbhadani
Created January 27, 2022 03:10
Show Gist options
  • Save rahulbhadani/b0100f4ed895e3cd38165fdaf5b532a6 to your computer and use it in GitHub Desktop.
Save rahulbhadani/b0100f4ed895e3cd38165fdaf5b532a6 to your computer and use it in GitHub Desktop.
GW Distance
import matplotlib.pylab as pl
import numpy as np
import ot
import scipy as sp
import scipy.linalg  # explicit: a bare `import scipy` does not guarantee the linalg submodule is loaded
from mpl_toolkits.mplot3d import Axes3D # noqa
# Build two point clouds that live in spaces of different dimension
# (a 2-D source and a 3-D target) and display them side by side.

n_samples = 50  # number of points drawn for each cloud

# Source distribution parameters: 2-D mean and identity covariance.
mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

# Target distribution parameters: 3-D mean (4, 4, 4) and identity covariance.
mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

# Source samples come from POT's helper, parameterised by mu_s / cov_s.
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)

# Target samples: standard-normal draws multiplied by the matrix square
# root of cov_t, then shifted by mu_t.  `scipy.linalg` is referenced
# through its explicit import so this does not depend on `import scipy`
# having loaded the submodule as a side effect.
P = scipy.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t

fig = pl.figure()

# Left panel: 2-D source cloud.
ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+g', label='Source Samples')
ax1.legend()  # without a legend the label above is never displayed

# Right panel: 3-D target cloud.
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r', label='Target Samples')
ax2.legend()

pl.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment