Skip to content

Instantly share code, notes, and snippets.

@BarclayII
Last active October 3, 2017 02:39
Show Gist options
  • Save BarclayII/64fabd04e58b92c3d192fb8c02d2a2b0 to your computer and use it in GitHub Desktop.
Save BarclayII/64fabd04e58b92c3d192fb8c02d2a2b0 to your computer and use it in GitHub Desktop.
MMD
import torch as T
import numpy as NP
### Norm of vector difference
# Checker
def normdiff_assert(X, Y, normdiff):
    '''
    Check a pairwise squared-distance implementation against a brute-force loop.

    Input:
        X: 3-D tensor (batch, feature_dim, nX)
        Y: 3-D tensor (batch, feature_dim, nY)
        normdiff: callable (X, Y) -> tensor indexed as [batch, i, j]
    Raises AssertionError as soon as some entry [:, i, j] is not NP.allclose
    to ||X[:, :, i] - Y[:, :, j]||^2 (computed per batch element).
    '''
    candidate = normdiff(X, Y)
    num_x = X.size()[2]
    num_y = Y.size()[2]
    for col_x in range(num_x):
        x_col = X[:, :, col_x]
        for col_y in range(num_y):
            # Squared L2 norm over the feature dimension -> one value per batch.
            expected = (x_col - Y[:, :, col_y]).norm(2, 1) ** 2
            got = candidate[:, col_x, col_y]
            assert NP.allclose(expected.numpy(), got.numpy())
# 4D solution
def normdiff_4d(X, Y):
    '''
    Pairwise squared Euclidean distances via a 4-D broadcast.

    Input:
        X: 3-D tensor (batch, feature_dim, nX)
        Y: 3-D tensor (batch, feature_dim, nY)
    Output:
        3-D tensor (batch, nX, nY) with [b, i, j] = ||X[b, :, i] - Y[b, :, j]||^2.
    Materializes a (batch, feature_dim, nX, nY) intermediate.
    '''
    x_cols = X.unsqueeze(3)   # (batch, feature_dim, nX, 1)
    y_cols = Y.unsqueeze(2)   # (batch, feature_dim, 1, nY)
    pairwise = x_cols - y_cols
    return pairwise.norm(2, 1).pow(2)
# 3D solution
def normdiff_3d(X, Y):
    '''
    Pairwise squared Euclidean distances via ||x||^2 - 2 x.y + ||y||^2.

    Only the single cross product X^T Y is formed with bmm; the squared norms
    come straight from the columns.

    Input:
        X: 3-D tensor (batch, feature_dim, nX)
        Y: 3-D tensor (batch, feature_dim, nY)
    Output:
        3-D tensor (batch, nX, nY) with [b, i, j] = ||X[b, :, i] - Y[b, :, j]||^2,
        clamped at 0.
    '''
    # Squared column norms. These equal the diagonals of X^T X and Y^T Y, but
    # computing them directly costs O(feature_dim * n) instead of building full
    # n x n Gram matrices. It also needs no index tensor, so it works on
    # whatever device X and Y live on.
    X_sqnorm = (X * X).sum(1).unsqueeze(2)   # (batch, nX, 1)
    Y_sqnorm = (Y * Y).sum(1).unsqueeze(1)   # (batch, 1, nY)
    XY = X.permute(0, 2, 1).bmm(Y)            # (batch, nX, nY)
    # The expansion can land slightly below zero through floating-point
    # cancellation; a squared distance never can.
    return (X_sqnorm - XY * 2 + Y_sqnorm).clamp(min=0)
### MMD computation
def mmd_naive(X, Y, sigma=[1]):
    '''
    Reference (loop-based) squared MMD between paired sample sets.

    Deliberately unvectorized so the faster implementations can be checked
    against it.

    Input:
        X: 3-D tensor (num_sample_sets, feature_dim, nX)
        Y: 3-D tensor (num_sample_sets, feature_dim, nY)
        sigma: iterable of bandwidths; each s adds the term
            exp(-||a - b||^2 / (2 * s)) to the kernel.
    Output:
        1-D tensor of size (num_sample_sets,) where
        L[i] = MMD^2(X[i], Y[i]). This is the biased estimate: self-pairs
        (a == b) are included in the within-set sums.
    '''
    n_x = X.size()[2]
    n_y = Y.size()[2]
    coefficients = [-1. / (2 * s) for s in sigma]

    def pairwise_kernel_total(A, B, n_a, n_b):
        # Sum the multi-bandwidth kernel over every (column of A, column of B) pair.
        total = 0
        for a in range(n_a):
            for b in range(n_b):
                for c in coefficients:
                    total += (c * (A[:, :, a] - B[:, :, b]).norm(2, 1) ** 2).exp()
        return total

    k_xx = pairwise_kernel_total(X, X, n_x, n_x)
    k_xy = pairwise_kernel_total(X, Y, n_x, n_y)
    k_yy = pairwise_kernel_total(Y, Y, n_y, n_y)
    return k_xx / (n_x * n_x) - k_xy / (n_x * n_y) * 2 + k_yy / (n_y * n_y)
def mmd_normdiff(X, Y, sigma=[1]):
    '''
    Vectorized squared MMD built on normdiff_3d().

    According to cProfile this guy is the fastest among all implementations.
    Update: slowest on GPU.

    Input:
        X: 3-D tensor (num_sample_sets, feature_dim, nX)
        Y: 3-D tensor (num_sample_sets, feature_dim, nY)
        sigma: list of bandwidths; the kernel is the sum over s in sigma of
            exp(-||a - b||^2 / (2 * s)).
    Output:
        1-D tensor (num_sample_sets,) with the biased MMD^2 estimate
        (self-pairs included), the same estimator as mmd_naive().
    '''
    nX = X.size()[2]
    nY = Y.size()[2]
    # type_as puts the bandwidths on X's device and dtype. A CPU float tensor
    # of shape (1, 1, 1, S) cannot broadcast against CUDA distance tensors.
    scale = -1. / (2 * T.Tensor(sigma).type_as(X)).view(1, 1, 1, -1)
    # (batch, nA, nB, 1) * (1, 1, 1, S): one kernel term per bandwidth, then
    # sum over bandwidths and over all column pairs.
    kXX = (scale * normdiff_3d(X, X).unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kXY = (scale * normdiff_3d(X, Y).unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kYY = (scale * normdiff_3d(Y, Y).unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    return kXX / (nX * nX) - kXY / (nX * nY) * 2 + kYY / (nY * nY)
def mmd_oneshot(X, Y, sigma=[1]):
    '''
    Vectorized squared MMD without repeated calls of normdiff.

    All three Gram matrices are formed once and turned into squared distances
    inline. According to cProfile this guy is the slowest except the naive
    version.

    Input:
        X: 3-D tensor (num_sample_sets, feature_dim, nX)
        Y: 3-D tensor (num_sample_sets, feature_dim, nY)
        sigma: list of bandwidths; the kernel is the sum over s in sigma of
            exp(-||a - b||^2 / (2 * s)).
    Output:
        1-D tensor (num_sample_sets,) with the biased MMD^2 estimate
        (self-pairs included), the same estimator as mmd_naive().
    '''
    batch_size = X.size()[0]
    nX = X.size()[2]
    nY = Y.size()[2]
    # Bandwidths on X's device/dtype so they broadcast against GPU tensors.
    scale = -1. / (2 * T.Tensor(sigma).type_as(X)).view(1, 1, 1, -1)
    Xt = X.permute(0, 2, 1)
    XX = Xt.bmm(X)                    # (batch, nX, nX)
    XY = Xt.bmm(Y)                    # (batch, nX, nY)
    YY = Y.permute(0, 2, 1).bmm(Y)    # (batch, nY, nY)
    # Squared column norms (the Gram-matrix diagonals), computed directly so
    # that no index tensor is needed and this runs on any device.
    X_sq = (X * X).sum(1)             # (batch, nX)
    Y_sq = (Y * Y).sum(1)             # (batch, nY)
    X_col = X_sq.view(batch_size, nX, 1)
    X_row = X_sq.view(batch_size, 1, nX)
    Y_col = Y_sq.view(batch_size, nY, 1)
    Y_row = Y_sq.view(batch_size, 1, nY)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clamped so rounding cannot
    # produce a negative squared distance.
    norm_XX = (X_col - XX * 2 + X_row).clamp(min=0)
    norm_XY = (X_col - XY * 2 + Y_row).clamp(min=0)
    norm_YY = (Y_col - YY * 2 + Y_row).clamp(min=0)
    # (batch, nA, nB, 1) * (1, 1, 1, S) -> sum over bandwidths and pairs.
    kXX = (scale * norm_XX.unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kXY = (scale * norm_XY.unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kYY = (scale * norm_YY.unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    return kXX / (nX * nX) - kXY / (nX * nY) * 2 + kYY / (nY * nY)
def mmd_gmmn(X, Y, sigma=[1]):
    '''
    The implementation from GMMN: squared MMD as one weighted kernel sum.

    Certainly this guy is faster since it only did one mm.
    The original code looped over all sigma so I vectorized it.
    Update: the fastest on GPU.

    Input:
        X: 3-D tensor (num_sample_sets, feature_dim, nX)
        Y: 3-D tensor (num_sample_sets, feature_dim, nY)
        sigma: list of bandwidths; the kernel is the sum over s in sigma of
            exp(-||a - b||^2 / (2 * s)).
    Output:
        1-D tensor (num_sample_sets,) with the biased MMD^2 estimate
        (self-pairs included), the same estimator as mmd_naive().
    '''
    batch_size = X.size()[0]
    nX = X.size()[2]
    nY = Y.size()[2]
    n = nX + nY
    # Signed weights w: +1/nX for X columns, -1/nY for Y columns, so that
    # sum_ij w_i w_j k(z_i, z_j) = kXX / nX^2 - 2 kXY / (nX nY) + kYY / nY^2.
    # type_as keeps every constant on X's device and dtype.
    w = T.zeros(n, 1).type_as(X)
    w[:nX] = 1
    w = w / nX - (1 - w) / nY
    ww = w.mm(w.permute(1, 0))               # (n, n)
    Z = T.cat([X, Y], 2)                     # (batch, feature_dim, n)
    ZZ = Z.permute(0, 2, 1).bmm(Z)           # (batch, n, n)
    z_sq = (Z * Z).sum(1)                    # squared column norms, (batch, n)
    # prod_mat[b, i, j] = -||z_i - z_j||^2 / 2. It is clamped to be
    # non-positive so that rounding cannot push a kernel value above 1.
    prod_mat = (ZZ - 0.5 * z_sq.view(batch_size, n, 1)
                - 0.5 * z_sq.view(batch_size, 1, n)).clamp(max=0)
    bandwidths = T.Tensor(sigma).type_as(X).view(1, 1, 1, -1)
    K = T.exp(1. / bandwidths * prod_mat.unsqueeze(3))   # (batch, n, n, S)
    A = ww.unsqueeze(0).unsqueeze(3) * K
    loss = A.sum(3).sum(2).sum(1)
    return loss
# Demo. X and Y are independent softplus(Gaussian) samples, i.e. drawn from
# the same distribution. Z is Gaussian, rescaled to X's overall mean and std.
X = T.randn(1, 256, 1000)
Y = T.randn(1, 256, 2000)
Z = T.randn(1, 256, 2000)
X = T.log(1 + T.exp(X))
Y = T.log(1 + T.exp(Y))
Z = Z * X.std() + X.mean()
sigma = [2, 5, 10, 20, 40, 80]
# mmd_naive is not run here: its Python triple loop over every column pair
# would take far too long at these sample sizes.
mmd_gmmn_xz = mmd_gmmn(X, Z, sigma).numpy()
mmd_normdiff_xz = mmd_normdiff(X, Z, sigma).numpy()
mmd_xz = mmd_oneshot(X, Z, sigma).numpy()
mmd_xy = mmd_oneshot(X, Y, sigma).numpy()
# Fourth central moments pooled over all entries. float() gives plain Python
# scalars rather than 0-dim tensors.
X4 = float(((X - X.mean()) ** 4).mean())
Y4 = float(((Y - Y.mean()) ** 4).mean())
Z4 = float(((Z - Z.mean()) ** 4).mean())
X_std = float(X.std())
print('MMD between X and Z:', mmd_xz)
# Cross-check: the other vectorized implementations should agree closely.
print('MMD between X and Z (mmd_gmmn, mmd_normdiff):', mmd_gmmn_xz, mmd_normdiff_xz)
print('MMD between X and Y:', mmd_xy)
# NOTE(review): this threshold appears to plug in a kernel bound of 1 and
# m = 1000 at significance 0.05. However, the kernel above is a sum of
# len(sigma) Gaussians, and the functions return MMD^2 rather than MMD.
# Confirm the formula before relying on this comparison.
print('MMD acceptance:', (4 * 1 / NP.sqrt(1000)) * NP.sqrt(NP.log(1 / 0.05)))
print('Fourth moment diff of X-Z (and the one divided by std):', (X4 - Z4) ** 2, (X4 - Z4) ** 2 / X_std)
print('Fourth moment diff of X-Y (and the one divided by std):', (X4 - Y4) ** 2, (X4 - Y4) ** 2 / X_std)
# Sample output from an earlier run (data is unseeded, so values vary):
# MMD between X and Z: [ 0.00822765]
# MMD between X and Y: [ 0.00804275]
# MMD acceptance: 0.218933132204
# Fourth moment diff of X-Z (and the one divided by std): 0.0122624075138 0.0235590941978
# Fourth moment diff of X-Y (and the one divided by std): 2.47656061454e-05 4.75808072262e-05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment