# MMD: comparing PyTorch implementations of squared Maximum Mean Discrepancy
# with a mixture-of-Gaussian-kernels.
import torch as T
import numpy as NP
### Norm of vector difference
# Checker: compares a vectorized normdiff implementation against an explicit
# double loop over all sample pairs.
def normdiff_assert(X, Y, normdiff):
    norm = normdiff(X, Y)
    for i in range(X.size()[2]):
        for j in range(Y.size()[2]):
            assert NP.allclose(
                ((X[:, :, i] - Y[:, :, j]).norm(2, 1) ** 2).numpy(),
                norm[:, i, j].numpy()
            )
# 4D solution: broadcast pairwise differences into a
# (batch, feature_dim, nX, nY) tensor, then reduce over the feature dim.
def normdiff_4d(X, Y):
    XY = X.unsqueeze(3) - Y.unsqueeze(2)
    return XY.norm(2, 1) ** 2
# 3D solution: expand ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2 and compute
# the three terms with batched matrix products.
def normdiff_3d(X, Y):
    batch_size = X.size()[0]
    nX = X.size()[2]
    nY = Y.size()[2]
    XX = X.permute(0, 2, 1).bmm(X)
    XY = X.permute(0, 2, 1).bmm(Y)
    YY = Y.permute(0, 2, 1).bmm(Y)
    XX_diag_indices = T.arange(0, nX).long().view(1, nX, 1).expand(batch_size, nX, 1)
    YY_diag_indices = T.arange(0, nY).long().view(1, 1, nY).expand(batch_size, 1, nY)
    XX_diag = XX.gather(2, XX_diag_indices)
    YY_diag = YY.gather(1, YY_diag_indices)
    return XX_diag - XY * 2 + YY_diag
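# Sanity-check sketch (an addition, not part of the original gist): both
# vectorized solutions should match the explicit double loop. The shapes
# below are arbitrary assumptions.
_Xc = T.randn(4, 8, 5)   # (num_sample_sets, feature_dim, num_samples)
_Yc = T.randn(4, 8, 7)
normdiff_assert(_Xc, _Yc, normdiff_4d)
normdiff_assert(_Xc, _Yc, normdiff_3d)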
### MMD computation
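# Worked formula (added for reference): with the mixture-of-Gaussians kernel
#   k(a, b) = sum_s exp(-||a - b||^2 / (2 * sigma_s)),
# every function below computes the biased estimator
#   MMD^2(X, Y) = sum_{i,j} k(x_i, x_j) / nX^2
#               - 2 * sum_{i,j} k(x_i, y_j) / (nX * nY)
#               + sum_{i,j} k(y_i, y_j) / nY^2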
def mmd_naive(X, Y, sigma=[1]):
    '''
    Computes sample-set-wise MMD.
    Naive version to verify correctness.
    Input:
        X, Y: 3-D tensor (num_sample_sets, feature_dim, num_samples)
    Output:
        L: 1-D tensor of size (num_sample_sets,), where
            L[i] = MMD^2(X[i], Y[i])
    '''
    nX = X.size()[2]
    nY = Y.size()[2]
    scale = [-1. / (2 * s) for s in sigma]
    kXX = 0
    kXY = 0
    kYY = 0
    for i in range(nX):
        for j in range(nX):
            for s in scale:
                kXX += (s * (X[:, :, i] - X[:, :, j]).norm(2, 1) ** 2).exp()
    for i in range(nX):
        for j in range(nY):
            for s in scale:
                kXY += (s * (X[:, :, i] - Y[:, :, j]).norm(2, 1) ** 2).exp()
    for i in range(nY):
        for j in range(nY):
            for s in scale:
                kYY += (s * (Y[:, :, i] - Y[:, :, j]).norm(2, 1) ** 2).exp()
    return kXX / (nX * nX) - kXY / (nX * nY) * 2 + kYY / (nY * nY)
def mmd_normdiff(X, Y, sigma=[1]):
    '''
    Vectorized version built on normdiff_3d().
    According to cProfile this is the fastest implementation on CPU.
    Update: the slowest on GPU.
    '''
    nX = X.size()[2]
    nY = Y.size()[2]
    scale = -1. / (2 * T.Tensor(sigma)).view(1, 1, 1, -1)
    kXX = (scale * normdiff_3d(X, X).unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kXY = (scale * normdiff_3d(X, Y).unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kYY = (scale * normdiff_3d(Y, Y).unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    return kXX / (nX * nX) - kXY / (nX * nY) * 2 + kYY / (nY * nY)
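# Cross-check sketch (an addition, not part of the original gist): on a tiny
# problem the naive and vectorized estimates should agree up to float error.
# Shapes and sigma values here are arbitrary assumptions.
_Xs = T.randn(2, 4, 6)
_Ys = T.randn(2, 4, 9)
assert NP.allclose(mmd_naive(_Xs, _Ys, [1, 5]).numpy(),
                   mmd_normdiff(_Xs, _Ys, [1, 5]).numpy(),
                   atol=1e-5)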
def mmd_oneshot(X, Y, sigma=[1]):
    '''
    Vectorized version without repeated calls to normdiff.
    According to cProfile this is the slowest implementation except for
    the naive version.
    '''
    batch_size = X.size()[0]
    nX = X.size()[2]
    nY = Y.size()[2]
    scale = -1. / (2 * T.Tensor(sigma)).view(1, 1, 1, -1)
    XX = X.permute(0, 2, 1).bmm(X)
    XY = X.permute(0, 2, 1).bmm(Y)
    YY = Y.permute(0, 2, 1).bmm(Y)
    XX_diag_indices = T.arange(0, nX).long().view(1, nX, 1).expand(batch_size, nX, 1)
    YY_diag_indices = T.arange(0, nY).long().view(1, 1, nY).expand(batch_size, 1, nY)
    XX_diag = XX.gather(2, XX_diag_indices)
    YY_diag = YY.gather(1, YY_diag_indices)
    norm_XX = XX_diag.view(batch_size, 1, nX) - XX * 2 + XX_diag
    norm_XY = XX_diag - XY * 2 + YY_diag
    norm_YY = YY_diag - YY * 2 + YY_diag.view(batch_size, nY, 1)
    kXX = (scale * norm_XX.unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kXY = (scale * norm_XY.unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    kYY = (scale * norm_YY.unsqueeze(3)).exp().sum(3).sum(2).sum(1)
    return kXX / (nX * nX) - kXY / (nX * nY) * 2 + kYY / (nY * nY)
def mmd_gmmn(X, Y, sigma=[1]):
    '''
    The implementation adapted from GMMN.
    It should be fast since it performs a single bmm over the
    concatenated samples.
    The original code looped over all sigma, so I vectorized it.
    Update: the fastest on GPU.
    '''
    batch_size = X.size()[0]
    nX = X.size()[2]
    nY = Y.size()[2]
    # Signed weights: +1/nX for samples from X, -1/nY for samples from Y.
    s = T.zeros(nX + nY, 1)
    s[:nX] = 1
    s = s / nX - (1 - s) / nY
    W = s
    X = T.cat([X, Y], 2)
    XX = X.permute(0, 2, 1).bmm(X)
    XX_diag_indices = T.arange(0, nX + nY).long().view(1, nX + nY, 1).expand(batch_size, nX + nY, 1)
    x = XX.gather(2, XX_diag_indices)
    # prod_mat[:, i, j] = -0.5 * ||x_i - x_j||^2
    prod_mat = XX - 0.5 * x - 0.5 * x.view(batch_size, 1, nX + nY)
    ww = W.mm(W.permute(1, 0))
    sigma = T.Tensor(sigma).view(1, 1, 1, -1)
    K = T.exp(1. / sigma * prod_mat.unsqueeze(3))
    A = ww.unsqueeze(0).unsqueeze(3) * K
    loss = A.sum(3).sum(2).sum(1)
    return loss
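# Timing sketch (an addition, not part of the original gist): a rough
# wall-clock comparison backing the profiling remarks in the docstrings
# above. Sizes, repeat count, and sigma values are arbitrary assumptions.
import time
_Xt = T.randn(1, 64, 300)
_Yt = T.randn(1, 64, 400)
for _fn in (mmd_normdiff, mmd_oneshot, mmd_gmmn):
    _t0 = time.time()
    for _ in range(10):
        _fn(_Xt, _Yt, [1, 5, 10])
    print('%s: %.4f s' % (_fn.__name__, time.time() - _t0))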
X = T.randn(1, 256, 1000)
Y = T.randn(1, 256, 2000)
Z = T.randn(1, 256, 2000)
X = T.log(1 + T.exp(X))     # softplus; X and Y come from the same distribution
Y = T.log(1 + T.exp(Y))
Z = Z * X.std() + X.mean()  # Gaussian with the same mean and std as X
sigma = [2, 5, 10, 20, 40, 80]
mmd_gmmn_xz = mmd_gmmn(X, Z, sigma).numpy()
#mmd_naive_xz = mmd_naive(X, Z, sigma).numpy()
mmd_normdiff_xz = mmd_normdiff(X, Z, sigma).numpy()
mmd_xz = mmd_oneshot(X, Z, sigma).numpy()
mmd_xy = mmd_oneshot(X, Y, sigma).numpy()
X4 = ((X - X.mean()) ** 4).mean()
Y4 = ((Y - Y.mean()) ** 4).mean()
Z4 = ((Z - Z.mean()) ** 4).mean()
print('MMD between X and Z:', mmd_xz)
print('MMD between X and Y:', mmd_xy)
# Acceptance threshold for the two-sample test (alpha = 0.05, m = 1000 samples)
print('MMD acceptance:', (4 * 1 / NP.sqrt(1000)) * NP.sqrt(NP.log(1 / 0.05)))
print('Fourth moment diff of X-Z (and the one divided by std):', (X4 - Z4) ** 2, (X4 - Z4) ** 2 / X.std())
print('Fourth moment diff of X-Y (and the one divided by std):', (X4 - Y4) ** 2, (X4 - Y4) ** 2 / X.std())
# Test run:
# MMD between X and Z: [ 0.00822765]
# MMD between X and Y: [ 0.00804275]
# MMD acceptance: 0.218933132204
# Fourth moment diff of X-Z (and the one divided by std): 0.0122624075138 0.0235590941978
# Fourth moment diff of X-Y (and the one divided by std): 2.47656061454e-05 4.75808072262e-05