Skip to content

Instantly share code, notes, and snippets.

@ShigekiKarita
Last active December 11, 2017 17:42
Show Gist options
  • Select an option

  • Save ShigekiKarita/68340a7d43e2bbf852227c1701748195 to your computer and use it in GitHub Desktop.

Select an option

Save ShigekiKarita/68340a7d43e2bbf852227c1701748195 to your computer and use it in GitHub Desktop.
%matplotlib notebook
from matplotlib import pyplot
import torch
import math
import numpy
from torch.nn import Parameter, Module
def dirichlet_log_pdf(pi, alpha):
    """Return the log-density of a Dirichlet distribution at ``pi``.

    Evaluates
    ``lgamma(sum(alpha)) + sum((alpha - 1) * log(pi)) - sum(lgamma(alpha))``.

    Args:
        pi: tensor of component probabilities; it is combined element-wise
            with ``alpha`` (presumably a 1-D point on the simplex -- the
            function itself does not check this).
        alpha: tensor of concentration parameters.

    Returns:
        The log-density as a tensor.
    """
    # log Gamma of the total concentration (alpha summed over dim 0).
    log_gamma_of_total = torch.lgamma(alpha.sum(0))
    # Data-dependent term: sum_k (alpha_k - 1) * log(pi_k).
    weighted_log_pi = torch.sum((alpha - 1.0) * torch.log(pi))
    # Sum of the per-component log Gamma terms.
    sum_of_log_gammas = torch.sum(torch.lgamma(alpha))
    return (log_gamma_of_total + weighted_log_pi) - sum_of_log_gammas
def normal_log_pdf(xs, means, cov):
    """Return log N(x | mean, cov * I) for every (sample, component) pair.

    Args:
        xs: tensor of shape ``(n_batch, n_dim)`` holding the samples.
        means: tensor of shape ``(n_component, n_dim)`` holding the
            component means.
        cov: positive Python number; the shared isotropic variance, i.e. the
            covariance matrix is ``cov * I``.

    Returns:
        Tensor of shape ``(n_batch, n_component)`` whose ``[i, j]`` entry is
        the log-density of ``xs[i]`` under component ``j``.

    Raises:
        TypeError: if ``cov`` is not a Python ``int`` or ``float``.
        ValueError: if ``cov`` is not positive (raised by ``math.log``).
    """
    if not isinstance(cov, (int, float)):
        raise TypeError(
            "cov must be a Python number, got {}".format(type(cov).__name__)
        )
    n_dim = xs.size(1)
    # Broadcast to (n_batch, n_component, n_dim) pairwise differences.
    diff = xs.unsqueeze(1) - means.unsqueeze(0)
    sq_dist = (diff * diff).sum(2)
    # Log normaliser of an isotropic Gaussian:
    #   -0.5 * (n_dim * log(2*pi) + log det(cov * I)),  det = cov ** n_dim.
    log_norm = -0.5 * n_dim * (math.log(2 * math.pi) + math.log(cov))
    return log_norm - 0.5 * sq_dist / cov
class GMMSampler(Module):
    """Gibbs sampler for the component means of an isotropic Gaussian mixture.

    Observations are modelled as ``N(x | means[z], cov_of_x * I)`` with a
    categorical prior ``exp(log_prior)`` over the assignment ``z`` and a
    Gaussian prior ``N(mean_of_mean, cov_of_mean * I)`` on every component
    mean.  ``sample`` only resamples the assignments and ``means``; the
    mixing weights and the hyper-parameters are left as initialised.
    """

    def __init__(self, n_dim, n_component):
        """Initialise ``n_component`` random means in ``n_dim`` dimensions."""
        super().__init__()
        self.n_dim = n_dim
        self.n_component = n_component
        self.means = Parameter(torch.randn(n_component, n_dim))
        # Isotropic observation variance (was a literal 1.0 at each use).
        self.cov_of_x = 1.0
        self.cov_of_mean = 1.0
        self.mean_of_mean = Parameter(torch.zeros(n_dim))
        # Uniform mixing weights, stored in log space.
        self.log_prior = Parameter(torch.log(torch.ones(n_component) / n_component))

    def select_k(self, xs, ids, k):
        """Return the rows of ``xs`` whose label in ``ids`` equals ``k``.

        Args:
            xs: tensor of shape ``(n_batch, n_dim)``.
            ids: tensor of ``n_batch`` labels.
            k: label to select.

        Returns:
            Tensor of shape ``(n_selected, n_dim)`` in the original row order
            (``n_selected`` may be 0).
        """
        return xs[ids == k].view(-1, self.n_dim)

    def joint_prob(self, xs, ids):
        """Return the log joint density ``log p(xs, ids, means)``.

        Despite the name, the value is in log space: the sum of the
        per-sample log-likelihood under its assigned component, the log
        mixing weight of that component, and the log prior of every mean.

        Args:
            xs: tensor of shape ``(n_batch, n_dim)``.
            ids: long tensor of ``n_batch`` component indices.
        """
        # Stay in log space; exponentiating and taking the log again loses
        # precision and underflows for points far from their component.
        log_px = normal_log_pdf(xs, self.means.data, self.cov_of_x)
        log_px_k = log_px.gather(1, ids.unsqueeze(1)).squeeze(1)
        log_pxz = torch.sum(log_px_k + self.log_prior.data.index_select(0, ids))
        log_p_mean = torch.sum(
            normal_log_pdf(
                self.means.data,
                self.mean_of_mean.data.unsqueeze(0),
                self.cov_of_mean,
            )
        )
        return log_pxz + log_p_mean

    def _sample_assignments(self, xs):
        """Draw one component index per row of ``xs`` from its posterior."""
        log_post = normal_log_pdf(xs, self.means.data, self.cov_of_x)
        log_post = log_post + self.log_prior.data.unsqueeze(0)
        # Subtract the row-wise maximum before exponentiating so at least one
        # weight per row is exactly 1 and the row can never be all zeros.
        log_post = log_post - log_post.max(dim=1, keepdim=True)[0]
        weights = log_post.exp()
        weights = weights / weights.sum(dim=1, keepdim=True)
        return torch.multinomial(weights, 1).squeeze(1)

    def _sample_mean(self, x_k):
        """Draw a component mean from its conjugate Gaussian posterior.

        ``x_k`` holds the ``(n_k, n_dim)`` samples currently assigned to the
        component; with ``n_k == 0`` the draw comes from the prior.
        """
        n_k = x_k.size(0)
        prior_precision = 1.0 / self.cov_of_mean
        precision = n_k / self.cov_of_x + prior_precision
        weighted_sum = self.mean_of_mean.data * prior_precision
        if n_k > 0:
            weighted_sum = weighted_sum + x_k.sum(0) / self.cov_of_x
        # torch.normal expects a standard deviation, not a variance.
        return torch.normal(weighted_sum / precision, math.sqrt(1.0 / precision))

    def sample(self, xs, n_iter=1):
        """Run ``n_iter`` Gibbs sweeps over ``xs``, updating ``means`` in place.

        Each sweep resamples the assignments, then every component mean, and
        prints the resulting log joint density.

        Args:
            xs: tensor of shape ``(n_batch, n_dim)``.
            n_iter: number of sweeps.

        Raises:
            ValueError: if the feature size of ``xs`` is not ``n_dim``.
        """
        if xs.size(1) != self.n_dim:
            raise ValueError(
                "expected xs with {} features, got {}".format(self.n_dim, xs.size(1))
            )
        for _ in range(n_iter):
            component_ids = self._sample_assignments(xs)
            for k in range(self.n_component):
                x_k = self.select_k(xs, component_ids, k)
                self.means.data[k] = self._sample_mean(x_k)
            print(self.joint_prob(xs, component_ids))
def test_select_k_th():
    """Check that ``GMMSampler.select_k`` returns the rows labelled 1, in order."""
    n_dim, n_batch = 2, 10
    sampler = GMMSampler(n_dim, 3)
    points = torch.randn(n_batch, n_dim)
    labels = torch.zeros(n_batch)
    # Rows that are moved from label 0 to label 1.
    marked_rows = (0, 2, -1)
    for row in marked_rows:
        labels[row] = 1
    selected = sampler.select_k(points, labels, 1)
    for out_row, src_row in enumerate(marked_rows):
        assert torch.equal(selected[out_row], points[src_row])
# Run the self-check for select_k before the demo.
test_select_k_th()

use_cuda = False
n = 10
# Three 2-D clusters of n points each, centred at (0, 5), (5, 0) and (0, -5).
x1 = torch.randn(n, 2) + torch.FloatTensor([[0.0, 5.0]])
x2 = torch.randn(n, 2) + torch.FloatTensor([[5.0, 0.0]])
x3 = torch.randn(n, 2) + torch.FloatTensor([[0.0, -5.0]])
for x in [x1, x2, x3]:
    pyplot.scatter(x[:, 0].numpy(), x[:, 1].numpy())

xs = torch.cat([x1, x2, x3], dim=0)
gmm = GMMSampler(2, 3)
if use_cuda:
    gmm.cuda()
    xs = xs.cuda()
gmm.sample(xs, 10)

# Plot the sampled means from a detached CPU copy: matplotlib cannot convert
# a Parameter that requires grad (or lives on the GPU) to a NumPy array.
for m in gmm.means.data.cpu().numpy():
    pyplot.scatter(m[0], m[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment