@talesa
Last active January 4, 2019 14:46
GMM model sampling and log_pdf in PyTorch
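The two functions below implement, as read off the code (the gist itself does not state the model), the following generative model for a batch of two-dimensional, K-component Gaussian mixtures, each dataset consisting of N points:

phi             ~ Dirichlet(beta * ones(K))         # mixing proportions
1 / sigma2_kd   ~ Gamma(upsilon, rate=sigma2_0_d)   # per-component, per-dimension precisions
mu_kd           ~ Normal(mu_0_d, sigma2_0_d)        # component means
rho_k           ~ Uniform(0.01, 0.99)               # within-component correlation
z_n | phi       ~ Categorical(phi)                  # mixture assignment of point n
x_n | z_n, ...  ~ MultivariateNormal(mu_{z_n}, Sigma_{z_n})

Here Sigma_k is the 2x2 covariance matrix with per-dimension standard deviations sigma_k and correlation rho_k, passed to PyTorch through its lower-triangular Cholesky factor scale_tril.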
import torch
import torch.distributions as dist
# Generates a sample from the generative model
def gmm_generate_data(K, N, batch_size=10,
                      upsilon=torch.Tensor([2.]).to('cpu'),
                      mu_0=torch.Tensor([0., 0.]).to('cuda'),
                      sigma2_0=torch.Tensor([2., 2.]).to('cuda')):
    # Sample the mixing proportions over mixture components
    beta = torch.Tensor([1.]).to('cpu')
    phi_dist = dist.Dirichlet(beta * torch.ones(K).to('cpu'))
    # batch_size x K
    phi = phi_dist.sample((batch_size,)).to('cuda')
    # Sample per-component precisions; sigma2 holds the variances
    # (reciprocals of the Gamma-distributed precisions)
    # batch_size x K x 2
    sigma2_dist = dist.Gamma(upsilon, sigma2_0.to('cpu'))
    sigma2 = sigma2_dist.sample((batch_size, K)).to('cuda').reciprocal()
    # Sample mixture component means
    # (dist.Normal takes a standard deviation, hence the sqrt of the prior variance)
    # batch_size x K x 2
    mu_dist = dist.Normal(mu_0, sigma2_0.sqrt())
    mu = mu_dist.sample((batch_size, K))
    # Sample correlations
    # batch_size x K
    rho_dist = dist.Uniform(low=torch.Tensor([0.01]).to('cuda'),
                            high=torch.Tensor([0.99]).to('cuda'))
    rho = rho_dist.sample((batch_size, K)).squeeze(2)
    # Sample mixture assignments
    # batch_size x N
    z_dist = dist.Categorical(phi)
    z = z_dist.sample((N,)).t()
    # Assemble the lower-triangular Cholesky factor of each 2x2 covariance:
    # a = sigma_0
    # b = rho * sigma_1
    # c = sigma_1 * sqrt(1 - rho^2)
    # batch_size x K x 2 x 2
    a = sigma2[:, :, 0].sqrt()
    b = rho * sigma2[:, :, 1].sqrt()
    c = sigma2[:, :, 1].sqrt() * (1 - rho.pow(2)).sqrt()
    zeros = torch.zeros_like(a)
    # Stack the rows along dim=-2 so the result is [[a, 0], [b, c]],
    # i.e. genuinely lower triangular as scale_tril requires
    scale_tril = torch.stack((torch.stack((a, zeros), dim=2),
                              torch.stack((b, c), dim=2)),
                             dim=-2)
    # Pick out the mean and Cholesky factor of the assigned component for each point
    x_loc = mu.gather(dim=1, index=z.unsqueeze(dim=2).expand(-1, -1, 2))
    x_scale_tril = scale_tril.gather(dim=1, index=z.view(batch_size, N, 1, 1).expand(-1, -1, 2, 2))
    x_dist = dist.MultivariateNormal(loc=x_loc, scale_tril=x_scale_tril)
    x = x_dist.sample()
    return {'phi': phi,
            'mu': mu,
            'sigma2': sigma2,
            'rho': rho,
            'z': z,
            'x': x}
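The Cholesky assembly above is easy to get wrong, so here is a minimal sanity check (not part of the original gist; it reuses the torch import above and runs on CPU) that the factor [[s0, 0], [rho*s1, s1*sqrt(1-rho^2)]] reproduces the intended covariance [[s0^2, rho*s0*s1], [rho*s0*s1, s1^2]]:

# Sanity check: L @ L.T should equal the 2x2 covariance with
# standard deviations s0, s1 and correlation rho
s0, s1, rho = 1.5, 0.7, 0.3
L = torch.tensor([[s0, 0.],
                  [rho * s1, s1 * (1 - rho ** 2) ** 0.5]])
cov = L @ L.t()
expected = torch.tensor([[s0 ** 2, rho * s0 * s1],
                         [rho * s0 * s1, s1 ** 2]])
assert torch.allclose(cov, expected)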
# Evaluates the log pdf of samples under the prior
def gmm_log_pdf(phi, mu, sigma2, rho, z, x,
                upsilon=2,
                mu_0=torch.Tensor([0., 0.]).to('cuda'),
                sigma2_0=torch.Tensor([2., 2.]).to('cuda')):
    batch_size = phi.shape[0]
    K = phi.shape[1]
    N = x.shape[1]
    log_prob = 0
    beta = 1.
    # Mixing proportions
    log_prob += dist.Dirichlet(beta * torch.ones(K).to('cuda')).log_prob(phi)
    # Mixture component means
    # (dist.Normal takes a standard deviation, hence the sqrt, matching gmm_generate_data)
    # batch_size x K x 2
    log_prob += dist.Normal(mu_0, sigma2_0.sqrt()).log_prob(mu).sum(dim=2).sum(dim=1)
    # Precisions 1/sigma2; note this evaluates the Gamma density of the precision,
    # without the change-of-variables Jacobian for a density over sigma2 itself
    # batch_size x K x 2
    log_prob += dist.Gamma(upsilon, sigma2_0)\
        .log_prob(sigma2.reciprocal())\
        .sum(dim=2).sum(dim=1)
    # Correlations
    log_prob += dist.Uniform(low=torch.Tensor([0.01]).to('cuda'),
                             high=torch.Tensor([0.99]).to('cuda'))\
        .log_prob(rho).sum(dim=1)
    # Mixture assignments
    log_prob += dist.Categorical(phi.unsqueeze(1).expand(batch_size, N, K))\
        .log_prob(z).sum(dim=1)
    # Assemble the lower-triangular Cholesky factor of each 2x2 covariance:
    # a = sigma_0
    # b = rho * sigma_1
    # c = sigma_1 * sqrt(1 - rho^2)
    # batch_size x K x 2 x 2
    a = sigma2[:, :, 0].sqrt()
    b = rho * sigma2[:, :, 1].sqrt()
    c = sigma2[:, :, 1].sqrt() * (1 - rho.pow(2)).sqrt()
    zeros = torch.zeros_like(a)
    # Stack the rows along dim=-2 so the result is [[a, 0], [b, c]], lower triangular
    scale_tril = torch.stack((torch.stack((a, zeros), dim=2),
                              torch.stack((b, c), dim=2)),
                             dim=-2)
    x_loc = mu.gather(dim=1, index=z.unsqueeze(dim=2).expand(-1, -1, 2))
    x_scale_tril = scale_tril.gather(dim=1, index=z.view(batch_size, N, 1, 1).expand(-1, -1, 2, 2))
    # Likelihood of the observations under their assigned components
    log_prob += dist.MultivariateNormal(loc=x_loc, scale_tril=x_scale_tril)\
        .log_prob(x).sum(dim=1)
    return log_prob
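A round-trip usage sketch (not in the original gist; it assumes a CUDA device is available, since the defaults above place tensors on 'cuda'): the dictionary returned by gmm_generate_data has exactly the keys gmm_log_pdf expects, so it can be unpacked directly.

# Sample 10 datasets of N=100 points from K=3 mixtures, then score each under the prior
data = gmm_generate_data(K=3, N=100, batch_size=10)
log_p = gmm_log_pdf(**data)   # shape: (batch_size,)
print(log_p.shape, log_p[:3])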