Last active
January 4, 2019 14:46
-
-
Save talesa/7eb77db186ff76afa7bcde3416a592c3 to your computer and use it in GitHub Desktop.
GMM model sampling and log_pdf in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.distributions as dist
def _gmm_hyperparam(value, default, device):
    """Return ``value`` (``default`` when it is None) as a tensor on ``device``."""
    if value is None:
        value = default
    if not torch.is_tensor(value):
        # Plain numbers/lists become floating tensors so they can parameterise
        # the distributions below.
        value = torch.tensor(value, dtype=torch.get_default_dtype())
    return value.to(device)


def _gmm_scale_tril(sigma2, rho):
    """Build lower Cholesky factors of the 2x2 component covariances.

    With s0 = sqrt(sigma2[..., 0]) and s1 = sqrt(sigma2[..., 1]) the covariance

        [[s0^2,          rho * s0 * s1],
         [rho * s0 * s1, s1^2         ]]

    has the lower-triangular factor

        [[s0,       0                     ],
         [rho * s1, s1 * sqrt(1 - rho^2)]]

    Args:
        sigma2: tensor of shape (..., 2) holding the marginal variances.
        rho: tensor of shape (...) holding the correlations.

    Returns:
        Tensor of shape (..., 2, 2), lower-triangular in its last two dims.
    """
    sigma = sigma2.sqrt()
    sigma_0, sigma_1 = sigma[..., 0], sigma[..., 1]
    row_0 = torch.stack((sigma_0, torch.zeros_like(sigma_0)), dim=-1)
    row_1 = torch.stack((rho * sigma_1,
                         sigma_1 * (1 - rho.pow(2)).sqrt()), dim=-1)
    # Stack the rows along dim=-2 so that row_0 is the first matrix row;
    # stacking along dim=-1 would yield the transposed (upper) matrix.
    return torch.stack((row_0, row_1), dim=-2)


def _gmm_select_components(mu, scale_tril, z):
    """Pick, for every observation, the parameters of its assigned component.

    Args:
        mu: (batch_size, K, 2) component means.
        scale_tril: (batch_size, K, 2, 2) component Cholesky factors.
        z: (batch_size, N) integer component assignments.

    Returns:
        Tuple ``(x_loc, x_scale_tril)`` of shapes (batch_size, N, 2) and
        (batch_size, N, 2, 2).
    """
    x_loc = mu.gather(dim=1, index=z.unsqueeze(dim=2).expand(-1, -1, 2))
    # Indexing with None adds the trailing axes regardless of memory layout
    # (z is a transposed view when it comes from gmm_generate_data).
    x_scale_tril = scale_tril.gather(
        dim=1, index=z[:, :, None, None].expand(-1, -1, 2, 2))
    return x_loc, x_scale_tril


# Generates a sample from the generative model
def gmm_generate_data(K, N, batch_size=10,
                      upsilon=None,
                      mu_0=None,
                      sigma2_0=None,
                      device=None):
    """Sample parameters, assignments and 2-D observations of a GMM.

    Args:
        K: number of mixture components.
        N: number of observations per batch element.
        batch_size: number of independent mixtures to sample.
        upsilon: Gamma concentration for the precisions; defaults to [2.].
        mu_0: location of the Normal prior over means; defaults to [0., 0.].
        sigma2_0: used both as the Gamma rate for the precisions and as the
            scale of the Normal prior over means; defaults to [2., 2.].
        device: device of the returned tensors. Defaults to 'cuda' when CUDA
            is available and to 'cpu' otherwise.

    Returns:
        dict with keys 'phi' (batch_size, K), 'mu' (batch_size, K, 2),
        'sigma2' (batch_size, K, 2), 'rho' (batch_size, K),
        'z' (batch_size, N) and 'x' (batch_size, N, 2).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)
    cpu = torch.device('cpu')

    # Defaults are created per call: building CUDA tensors in the signature
    # would make the module unimportable on machines without a GPU.
    upsilon = _gmm_hyperparam(upsilon, [2.], cpu)
    mu_0 = _gmm_hyperparam(mu_0, [0., 0.], device)
    sigma2_0 = _gmm_hyperparam(sigma2_0, [2., 2.], device)

    # Mixture weights, batch_size x K. The Dirichlet and Gamma draws are made
    # on the CPU and then moved, as in the original code.
    beta = 1.
    phi_dist = dist.Dirichlet(beta * torch.ones(K, device=cpu))
    phi = phi_dist.sample((batch_size,)).to(device)

    # Component variances, batch_size x K x 2: a Gamma draw is inverted, so
    # the Gamma is a prior over the precision 1 / sigma2.
    precision_dist = dist.Gamma(upsilon, sigma2_0.to(cpu))
    sigma2 = precision_dist.sample((batch_size, K)).to(device).reciprocal()

    # Component means, batch_size x K x 2.
    # NOTE(review): sigma2_0 is passed as the Normal *scale* (std), not as a
    # variance -- kept as in the original; confirm the intended prior.
    mu = dist.Normal(mu_0, sigma2_0).sample((batch_size, K))

    # Correlations, batch_size x K (the trailing unit axis is squeezed away).
    rho_dist = dist.Uniform(low=torch.tensor([0.01], device=device),
                            high=torch.tensor([0.99], device=device))
    rho = rho_dist.sample((batch_size, K)).squeeze(2)

    # Component assignments, batch_size x N.
    z = dist.Categorical(phi).sample((N,)).t()

    # Observations, batch_size x N x 2.
    scale_tril = _gmm_scale_tril(sigma2, rho)
    x_loc, x_scale_tril = _gmm_select_components(mu, scale_tril, z)
    x = dist.MultivariateNormal(loc=x_loc, scale_tril=x_scale_tril).sample()

    return {'phi': phi,
            'mu': mu,
            'sigma2': sigma2,
            'rho': rho,
            'z': z,
            'x': x}


# Evaluate log pdf of samples under the prior
def gmm_log_pdf(phi, mu, sigma2, rho, z, x,
                upsilon=2,
                mu_0=None,
                sigma2_0=None):
    """Evaluate the joint log density of a batch of GMM samples.

    The tensor arguments have the shapes returned by ``gmm_generate_data``.
    All computation happens on ``phi.device``.

    Args:
        phi: (batch_size, K) mixture weights.
        mu: (batch_size, K, 2) component means.
        sigma2: (batch_size, K, 2) component variances.
        rho: (batch_size, K) component correlations.
        z: (batch_size, N) integer component assignments.
        x: (batch_size, N, 2) observations.
        upsilon: Gamma concentration for the precisions.
        mu_0: location of the Normal prior over means; defaults to [0., 0.].
        sigma2_0: Gamma rate / Normal scale; defaults to [2., 2.].

    Returns:
        Tensor of shape (batch_size,) with the summed log density terms.
    """
    device = phi.device
    batch_size, K = phi.shape[0], phi.shape[1]
    N = x.shape[1]

    upsilon = _gmm_hyperparam(upsilon, 2., device)
    mu_0 = _gmm_hyperparam(mu_0, [0., 0.], device)
    sigma2_0 = _gmm_hyperparam(sigma2_0, [2., 2.], device)

    # Mixture weights.
    beta = 1.
    log_prob = dist.Dirichlet(beta * torch.ones(K, device=device)).log_prob(phi)

    # Component means (sigma2_0 is used as the Normal scale, as in sampling).
    log_prob = log_prob + dist.Normal(mu_0, sigma2_0).log_prob(mu)\
                              .sum(dim=2).sum(dim=1)

    # Component variances.
    # NOTE(review): this is the Gamma density evaluated at the precision
    # 1 / sigma2; no change-of-variables (Jacobian) term is added.
    log_prob = log_prob + dist.Gamma(upsilon, sigma2_0)\
                              .log_prob(sigma2.reciprocal())\
                              .sum(dim=2).sum(dim=1)

    # Correlations.
    rho_dist = dist.Uniform(low=torch.tensor([0.01], device=device),
                            high=torch.tensor([0.99], device=device))
    log_prob = log_prob + rho_dist.log_prob(rho).sum(dim=1)

    # Component assignments: the weights are repeated for each observation.
    z_dist = dist.Categorical(phi.unsqueeze(1).expand(batch_size, N, K))
    log_prob = log_prob + z_dist.log_prob(z).sum(dim=1)

    # Observations given their assigned component.
    scale_tril = _gmm_scale_tril(sigma2, rho)
    x_loc, x_scale_tril = _gmm_select_components(mu, scale_tril, z)
    x_dist = dist.MultivariateNormal(loc=x_loc, scale_tril=x_scale_tril)
    log_prob = log_prob + x_dist.log_prob(x).sum(dim=1)

    return log_prob
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment