Skip to content

Instantly share code, notes, and snippets.

@DuaneNielsen
Created January 19, 2020 00:37
Show Gist options
  • Save DuaneNielsen/45ac9f3a70cb6091a38984c35e372e6a to your computer and use it in GitHub Desktop.
Save DuaneNielsen/45ac9f3a70cb6091a38984c35e372e6a to your computer and use it in GitHub Desktop.
Expectation Max algo Pytorch
import torch
from torch import tensor
from torch.distributions.multivariate_normal import MultivariateNormal
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.patches as mpatches
from math import acos, degrees
from matplotlib.pyplot import cm
"""
Expectation-Maximization (EM) demo: fits a 2-D Gaussian mixture model in PyTorch
"""
def plot(x, h, n_stdev_bands=3):
    """Scatter the data and overlay every Gaussian component as nested ellipses.

    Args:
        x: point cloud; columns 0 and 1 are used as the plot coordinates.
        h: distribution object exposing ``mean`` (one row per component) and
            ``covariance_matrix`` (one matrix per component), e.g. a batched
            ``MultivariateNormal``.
        n_stdev_bands: number of ellipses drawn per component, at
            1, 2, ..., ``n_stdev_bands`` standard deviations.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # torch.symeig has been removed from current PyTorch; torch.linalg.eigh
    # returns the same (ascending eigenvalues, column eigenvectors) pair.
    # Older installs without torch.linalg.eigh keep using symeig.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
        variances, eigvecs = torch.linalg.eigh(h.covariance_matrix)
    else:
        variances, eigvecs = torch.symeig(h.covariance_matrix, eigenvectors=True)
    # Round-off can make a tiny eigenvalue slightly negative; clamp so the
    # square root never produces NaN.
    stdevs = variances.clamp(min=0).sqrt()

    ax = plt.subplot(111, aspect='equal')
    # Limits start at zero, so the origin is always inside the view.
    max_x, min_x, max_y, min_y = 0.0, 0.0, 0.0, 0.0
    legend = []
    colors = cm.rainbow(torch.linspace(0, 1, h.mean.size(0)).tolist())

    for mean, stdev, axes, color in zip(h.mean, stdevs, eigvecs, colors):
        # Plain floats keep matplotlib and min()/max() away from 0-d tensors.
        cx, cy = mean[0].item(), mean[1].item()
        sx, sy = stdev[0].item(), stdev[1].item()
        legend.append(mpatches.Patch(color=color,
                                     label=f'mu {cx:.2f}, {cy:.2f} '
                                           f'sigma {sx:.2f} {sy:.2f}'))

        # Direction of the first eigenvector, the axis that `width` is
        # measured along. atan2 keeps the sign of the y component; acos of the
        # x component alone loses it and mirrors the ellipse whenever that
        # component is negative.
        angle = degrees(torch.atan2(axes[1, 0], axes[0, 0]).item())

        for j in range(1, n_stdev_bands + 1):
            ell = Ellipse(xy=(cx, cy),
                          width=sx * j * 2, height=sy * j * 2,
                          angle=angle,
                          alpha=1.0,
                          edgecolor=color,
                          fc='none')
            ax.add_artist(ell)
            # Axis limits reach twice the band's half-axis length from the
            # mean along each coordinate axis.
            max_x = max(max_x, cx + sx * j * 2)
            max_y = max(max_y, cy + sy * j * 2)
            min_x = min(min_x, cx - sx * j * 2)
            min_y = min(min_y, cy - sy * j * 2)

    plt.scatter(x[:, 0], x[:, 1])
    ax.set_xlim(min_x, max_x)
    ax.set_ylim(min_y, max_y)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.legend(handles=legend)
    plt.show()
def sample(n, mu, c):
    """Draw ``n`` points from a Gaussian centred on ``mu``.

    Args:
        n: number of points to draw.
        mu: mean vector; it is flattened to a column and its number of
            elements sets the dimensionality of the samples.
        c: ``(dims, dims)`` matrix applied to standard-normal noise, so the
            samples have covariance ``c @ c.T``.

    Returns:
        Tensor of shape ``(n, dims)`` with one sample per row.
    """
    # Take the dimensionality from ``mu`` instead of a module-level ``dims``
    # global, which only exists when the file is executed as a script.
    n_dims = mu.numel()
    # Match the dtype/device of ``c`` so the matmul below cannot fail on a
    # mismatch with the default tensor type.
    z = torch.randn(n_dims, n, dtype=c.dtype, device=c.device)
    # The noise is subtracted rather than added; since z is symmetric about
    # zero the resulting distribution is the same.
    return (mu.view(-1, 1) - c.matmul(z)).T
if __name__ == '__main__':
    # Fit a k-component Gaussian mixture to synthetic 2-D data with EM,
    # plotting the current estimate every fifth iteration.
    n = 600  # total number of points; must be divisible by 3 (three equal clusters)
    k = 3  # number of mixture components to fit
    dims = 2  # dimensionality of the data
    eps = torch.finfo(torch.float32).eps
    max_iter = 1000  # safety cap so a non-converging run cannot loop forever

    # Three synthetic clusters of n // 3 points each.
    x1 = sample(n // 3, tensor([3.0, 0.0]), torch.eye(2) * 1/4)
    x2 = sample(n // 3, tensor([0.0, 1.0]), torch.eye(2) * 1/4)
    x3 = sample(n // 3, tensor([2.0, 4.0]), torch.eye(2) * 1/4)

    # (n_points, 1, dims): the singleton axis broadcasts against the k components.
    x = torch.cat((x1, x2, x3)).unsqueeze(1)
    n_points = x.size(0)

    # Initial guess: random means, identity covariances.
    mu = torch.randn(k, dims)
    covar = torch.stack(k * [torch.eye(dims)])
    # Log mixing weights; kept uniform and not re-estimated during the fit.
    prior = torch.full((k,), 1.0 / k).log()

    converged = False
    i = 0
    h = None

    while not converged and i < max_iter:
        prev_mu = mu.clone()
        prev_covar = covar.clone()

        # E-step: log-posterior of each component for each point, normalised
        # in log space with logsumexp.
        h = MultivariateNormal(mu, covar)
        llhood = h.log_prob(x)  # (n_points, k)
        weighted_llhood = llhood + prior
        log_sum_lhood = torch.logsumexp(weighted_llhood, dim=1, keepdim=True)
        log_posterior = weighted_llhood - log_sum_lhood

        if i % 5 == 0:
            plot(x.squeeze(1), h)

        # Responsibilities, shape (n_points, k, 1).
        pi = torch.exp(log_posterior.reshape(n_points, k, 1))
        # Floor every responsibility at eps while keeping each row summing to
        # one, so no component ends up with zero total weight.
        pi = pi * (1 - k * eps) + eps

        # M-step: responsibility-weighted means and covariances.
        resp_sum = torch.sum(pi, dim=0)  # (k, 1)
        mu = torch.sum(x * pi, dim=0) / resp_sum
        # Scale the residuals by sqrt(pi) so the Gram product below weights
        # each outer product by pi itself; scaling by pi would weight it by
        # pi squared.
        delta = pi.sqrt() * (x - mu)  # (n_points, k, dims)
        covar = torch.matmul(delta.permute(1, 2, 0), delta.permute(1, 0, 2)) / resp_sum.reshape(k, 1, 1)

        converged = torch.allclose(mu, prev_mu) and torch.allclose(covar, prev_covar)
        i += 1

    # Show the final parameters, not the ones from before the last M-step.
    h = MultivariateNormal(mu, covar)
    plot(x.squeeze(1), h)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment