%matplotlib notebook
from matplotlib import pyplot
import torch
import math
from torch.nn import Parameter, Module

def dirichlet_log_pdf(pi, alpha):
    # log Dirichlet(pi; alpha) = lgamma(sum(alpha)) - sum(lgamma(alpha))
    #                            + sum((alpha - 1) * log(pi))
    numer = torch.lgamma(alpha.sum(0)) + torch.sum(torch.log(pi) * (alpha - 1.0))
    denom = torch.sum(torch.lgamma(alpha))
    return numer - denom
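
# Sanity check (added, not part of the original gist): with alpha = (1, 1)
# the Dirichlet is uniform on the simplex, so the log pdf is
# log Gamma(2) = 0 at any interior point.
assert abs(dirichlet_log_pdf(torch.FloatTensor([0.3, 0.7]),
                             torch.ones(2)).item()) < 1e-5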

def normal_log_pdf(xs, means, cov):
    # log density of an isotropic Gaussian N(mean, cov * I), evaluated for
    # every (sample, component) pair; returns an (n_batch, n_component) tensor
    n_batch, n_dim = xs.size()
    n_component = means.size(0)
    assert isinstance(cov, float)
    xs_ms = xs.unsqueeze(1) - means.unsqueeze(0)
    # normalization term: -(n_dim / 2) * log(2 * pi * cov)
    coeff = -0.5 * n_dim * (math.log(2 * math.pi) + math.log(cov))
    xms = xs_ms.view(n_batch * n_component, n_dim, 1)
    pdfs = coeff + (-0.5 * xms.transpose(1, 2).bmm(xms) / cov)
    return pdfs.view(n_batch, n_component)
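
# Sanity check (added): in one dimension with cov = 1.0 this should match
# the standard normal log density at x = 0, i.e. -0.5 * log(2 * pi).
_lp = normal_log_pdf(torch.zeros(1, 1), torch.zeros(1, 1), 1.0)
assert abs(_lp[0, 0].item() + 0.5 * math.log(2 * math.pi)) < 1e-5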

class GMMSampler(Module):
    # Gibbs sampler for a Gaussian mixture with unit-covariance components,
    # a fixed uniform prior over components and a standard-normal prior
    # over the component means.
    def __init__(self, n_dim, n_component):
        super().__init__()
        self.n_dim = n_dim
        self.n_component = n_component
        self.means = Parameter(torch.randn(n_component, n_dim))
        self.cov_of_mean = 1.0
        self.mean_of_mean = Parameter(torch.zeros(n_dim))
        self.log_prior = Parameter(torch.log(torch.ones(n_component) / n_component))

    def select_k(self, xs, ids, k):
        # gather the rows of xs currently assigned to component k
        mask = ids == k
        mask = mask.expand(self.n_dim, xs.size(0)).transpose(0, 1)
        return torch.masked_select(xs, mask).view(-1, self.n_dim)

    def joint_prob(self, xs, ids):
        # log p(x, z | mu) + log p(mu), kept in log space for stability
        log_px = normal_log_pdf(xs, self.means.data, 1.0)
        log_px_k = log_px[torch.arange(xs.size(0)).long(), ids]
        log_pxz = torch.sum(log_px_k + self.log_prior.data.index_select(0, ids))
        log_p_mean = torch.sum(normal_log_pdf(self.means.data,
                                              self.mean_of_mean.data.unsqueeze(0),
                                              self.cov_of_mean))
        return (log_pxz + log_p_mean).item()

    def sample(self, xs, n_iter=1):
        assert xs.size(1) == self.n_dim
        for i in range(n_iter):
            # sample assignments z from the responsibilities p(z = k | x, mu)
            pdfs = normal_log_pdf(xs, self.means.data, 1.0).exp()
            pdfs /= pdfs.sum(dim=1, keepdim=True)
            component_ids = torch.multinomial(pdfs, 1).squeeze(1)
            # resample each mean from its conjugate posterior
            for k in range(self.n_component):
                x_k = self.select_k(xs, component_ids, k)
                n_k = x_k.size(0)
                # sum(x_k) / (n_k + 1) equals n_k / (n_k + 1) * mean(x_k)
                # and stays at the prior mean (zero) for empty components
                post_mean = x_k.sum(dim=0) / (n_k + 1)
                # torch.normal expects a standard deviation, not a variance
                self.means.data[k] = torch.normal(post_mean,
                                                  math.sqrt(1.0 / (n_k + 1)))
            print(self.joint_prob(xs, component_ids))
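
# Numeric check (added): the conjugate update used in `sample`. With a
# N(0, I) prior and a unit-covariance likelihood, n_k points with sample
# mean x_bar give the posterior N(n_k / (n_k + 1) * x_bar, I / (n_k + 1)).
# For three points at (1, 1) the posterior mean is (0.75, 0.75).
_x_k = torch.ones(3, 2)
_post_mean = _x_k.sum(dim=0) / (_x_k.size(0) + 1)
assert torch.equal(_post_mean, torch.FloatTensor([0.75, 0.75]))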

def test_select_k_th():
    n_dim = 2
    gmm = GMMSampler(n_dim, 3)
    n_batch = 10
    xs = torch.randn(n_batch, n_dim)
    ids = torch.zeros(n_batch).long()
    ids[0] = 1
    ids[2] = 1
    ids[-1] = 1
    ys = gmm.select_k(xs, ids, 1)
    assert torch.equal(ys[0], xs[0])
    assert torch.equal(ys[1], xs[2])
    assert torch.equal(ys[2], xs[-1])

test_select_k_th()
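
# Smoke test (added): joint_prob should return a finite float for random
# data and random component assignments.
def test_joint_prob():
    gmm = GMMSampler(2, 3)
    xs = torch.randn(10, 2)
    ids = torch.multinomial(torch.ones(10, 3) / 3, 1).squeeze(1)
    assert math.isfinite(gmm.joint_prob(xs, ids))

test_joint_prob()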

use_cuda = False
n = 10
# three well-separated clusters for a visual check
x1 = torch.randn(n, 2) + torch.FloatTensor([[0.0, 5.0]])
x2 = torch.randn(n, 2) + torch.FloatTensor([[5.0, 0.0]])
x3 = torch.randn(n, 2) + torch.FloatTensor([[0.0, -5.0]])
for x in [x1, x2, x3]:
    pyplot.scatter(x[:, 0].numpy(), x[:, 1].numpy())
xs = torch.cat([x1, x2, x3], dim=0)
gmm = GMMSampler(2, 3)
if use_cuda:
    gmm.cuda()
    xs = xs.cuda()
gmm.sample(xs, 10)
# plot the sampled means; use .data to avoid touching the autograd graph
for m in gmm.means.data:
    pyplot.scatter(float(m[0]), float(m[1]))
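
# Expected outcome (added): after the ten Gibbs sweeps above, the sampled
# means should land near the true cluster centers (0, 5), (5, 0) and (0, -5).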