# Probabilistic embedding: each word gets a mean vector and a log-variance vector.
import torch
from torch import nn

embeddings_mu = nn.Embedding(n_words, n_dim)
embeddings_lv = nn.Embedding(n_words, n_dim)
...
vector_mu = embeddings_mu(c_index)
vector_lv = embeddings_lv(c_index)

def normal(mu, lv):
    # Reparameterization trick: draw eps ~ N(0, 1), then shift and scale it
    # so gradients can flow back through mu and lv.
    eps = torch.FloatTensor(mu.size()).normal_()
    return mu + eps * torch.exp(0.5 * lv)

c_vector = normal(vector_mu, vector_lv)

# Deterministic embedding, for comparison: a single point vector per word.
embeddings = nn.Embedding(n_words, n_dim)
...
c_vector = embeddings(c_index)
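For concreteness, here is a minimal, self-contained sketch of the probabilistic lookup above. The vocabulary size, dimensionality, and batch of word indices are made-up values (the original elides them), and it assumes a PyTorch version recent enough that Variable wrapping is not required.

import torch
from torch import nn

n_words, n_dim = 1000, 128                    # hypothetical vocabulary size / dimensionality
embeddings_mu = nn.Embedding(n_words, n_dim)
embeddings_lv = nn.Embedding(n_words, n_dim)

def normal(mu, lv):
    # mu + sigma * eps, with eps ~ N(0, 1) and sigma = exp(0.5 * log-variance)
    eps = torch.FloatTensor(mu.size()).normal_()
    return mu + eps * torch.exp(0.5 * lv)

c_index = torch.LongTensor([3, 17, 42])       # a hypothetical batch of word ids
c_vector = normal(embeddings_mu(c_index), embeddings_lv(c_index))
print(c_vector.size())                        # torch.Size([3, 128])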
import torch
import torch.autograd
import torch.nn.functional as F
from torch.autograd import Variable
from torch import nn
import numpy as np


def pairwise_l2(data):
    product = torch.mm(data, data.t())
    # get the diagonal elements (squared norms of each row)
    diag = product.diag().unsqueeze(0)
    diag = diag.expand_as(product)
    # compute the squared distance matrix: ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
    l2 = diag + diag.t() - 2 * product
    return l2


class TSNE(nn.Module):
    def __init__(self, n_points, n_dim):
        super(TSNE, self).__init__()
        self.logits = nn.Embedding(n_points, n_dim)

    def forward(self, pij, i, j):
        # Compute squared pairwise distances between all embedded points
        dkl2 = pairwise_l2(self.logits.weight)
        # Compute the partition function (Student-t kernel summed over all
        # off-diagonal pairs; each diagonal entry contributes exactly 1)
        n_diagonal = dkl2.size()[0]
        part = (1 + dkl2).pow(-1.0).sum() - n_diagonal
        # Compute the numerator: unnormalized similarity of each (i, j) pair
        xi = self.logits(i)
        xj = self.logits(j)
        num = ((1. + (xi - xj) ** 2.0).sum(1)).pow(-1.0).squeeze()
        # qij is the probability of picking the (i, j) relationship
        # out of N^2 other possible pairs in the 2D embedding.
        qij = num / part.expand_as(num)
        # Compute KLD between pij & qij
        loss_kld = pij * (torch.log(pij) - torch.log(qij))
        return loss_kld.sum()
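A hedged sketch of how this module might be driven. It assumes pij has already been computed elsewhere (e.g. from high-dimensional neighborhood affinities) and flattened to one probability per sampled (i, j) pair; the index tensors and pij below are random placeholders, and a recent PyTorch (no Variable wrapping) is assumed.

import torch
from torch.optim import SGD

n_points, n_dim = 500, 2                       # hypothetical dataset size, 2-D map
model = TSNE(n_points, n_dim)
opt = SGD(model.parameters(), lr=1e-2)

# Placeholder pairs and affinities; a real run would sample (i, j) according
# to the precomputed high-dimensional pij rather than uniformly at random.
i = torch.randint(0, n_points, (256,))
j = torch.randint(0, n_points, (256,))
pij = torch.rand(256)
pij = pij / pij.sum()

for step in range(100):
    opt.zero_grad()
    loss = model(pij, i, j)                    # KL(pij || qij) over the sampled pairs
    loss.backward()
    opt.step()

layout = model.logits.weight.detach()          # (n_points, 2) low-dimensional positions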
import torch
import torch.autograd
import torch.nn.functional as F
from torch.autograd import Variable
from torch import nn
import numpy as np


def pairwise_l2(data):
    product = torch.mm(data, data.t())
    # get the diagonal elements (squared norms of each row)
    diag = product.diag().unsqueeze(0)
    diag = diag.expand_as(product)
    # compute the squared distance matrix: ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
    l2 = diag + diag.t() - 2 * product
    return l2


def reparametrize(mu, logvar):
    # Reparameterization trick: sample z = mu + std * eps with eps ~ N(0, 1),
    # and also return the KL divergence of N(mu, std^2) from N(0, 1).
    std = logvar.mul(0.5).exp_()
    eps = torch.cuda.FloatTensor(std.size()).normal_()
    eps = Variable(eps)
    z = eps.mul(std).add_(mu)
    kld = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    kld = torch.sum(kld).mul_(-0.5)
    return z, kld


class TSNE(nn.Module):
    def __init__(self, n_points, n_dim):
        super(TSNE, self).__init__()
        self.logits = nn.Embedding(n_points, n_dim)

    def forward(self, pij, i, j):
        # Compute squared pairwise distances between all embedded points
        dkl2 = pairwise_l2(self.logits.weight)
        # Compute the partition function (Student-t kernel summed over all
        # off-diagonal pairs; each diagonal entry contributes exactly 1)
        n_diagonal = dkl2.size()[0]
        part = (1 + dkl2).pow(-1.0).sum() - n_diagonal
        # Compute the numerator: unnormalized similarity of each (i, j) pair
        xi = self.logits(i)
        xj = self.logits(j)
        num = ((1. + (xi - xj) ** 2.0).sum(1)).pow(-1.0).squeeze()
        # qij is the probability of picking the (i, j) relationship
        # out of N^2 other possible pairs in the 2D embedding.
        qij = num / part.expand_as(num)
        # Compute KLD between pij & qij
        loss_kld = pij * (torch.log(pij) - torch.log(qij))
        return loss_kld.sum()
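A small usage sketch for the reparametrize helper above, with made-up shapes. It needs a CUDA device, because the helper draws its noise with torch.cuda.FloatTensor. This file never wires the helper into TSNE.forward, so the closing comment is only a suggestion about how the pieces could fit together.

import torch
from torch.autograd import Variable

mu = Variable(torch.zeros(10, 2).cuda(), requires_grad=True)
logvar = Variable(torch.zeros(10, 2).cuda(), requires_grad=True)

z, kld = reparametrize(mu, logvar)
# z:   a (10, 2) stochastic sample, one low-dimensional position per point
# kld: a scalar KL(N(mu, sigma^2) || N(0, 1)) penalty that could be added,
#      with some weight, to the loss returned by TSNE.forward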