EM algorithm - 1D
Uses a log-probability Bayes update for numerical stability.
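As a standalone sketch of why the update is done in log space (the numbers and variable names here are illustrative, not part of the gist): a point far out in the tails of every component underflows to zero likelihood in probability space, so the naive Bayes update divides zero by zero, while the logsumexp form stays finite.

import torch
from torch.distributions.normal import Normal

# two hypotheses and a point far out in both tails
h = Normal(torch.tensor([-2.0, 2.0]), torch.tensor([0.5, 0.5]))
x = torch.tensor([[30.0]])
prior = torch.tensor([0.5, 0.5])

# naive update in probability space: both likelihoods underflow to 0,
# so normalising gives 0/0 = nan
lhood = torch.exp(h.log_prob(x)) * prior
print(lhood / lhood.sum(dim=1, keepdim=True))   # tensor([[nan, nan]])

# the same update in log space: logsumexp shifts by the largest log term
# before exponentiating, so the posterior comes out finite
weighted = h.log_prob(x) + prior.log()
log_posterior = weighted - torch.logsumexp(weighted, dim=1, keepdim=True)
print(log_posterior.exp())                      # tensor([[0., 1.]])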
import torch
from torch.distributions.normal import Normal
import matplotlib.pyplot as plt

"""
EM algorithm demo for a 1D Gaussian mixture, in PyTorch.
"""

n = 40  # number of samples, must be even
k = 2   # number of mixture components
eps = torch.finfo(torch.float32).eps  # guards against division by zero in the M step


def plot(x, posterior, h):
    fig, ax = plt.subplots(nrows=2, ncols=1)
    ax[0].title.set_text('p(H | x)')
    ax[0].bar(x.squeeze(), posterior[:, 0].squeeze())
    ax[0].bar(x.squeeze(), posterior[:, 1].squeeze(), bottom=posterior[:, 0])
    x_axis = torch.linspace(ax[0].get_xlim()[0], ax[0].get_xlim()[1], 50)
    ax[1].title.set_text('H')
    ax[1].plot(x_axis, torch.exp(h.log_prob(x_axis.expand(k, 50).T)), label=['h1', 'h2'])
    fig.tight_layout()
    plt.show()


if __name__ == '__main__':

    # draw an even split of samples from two well-separated Gaussians
    d1 = Normal(-2.0, 0.5)
    d2 = Normal(2.0, 0.5)
    x1 = d1.sample((n // 2,))
    x2 = d2.sample((n // 2,))
    x = torch.cat((x1, x2)).view(-1, 1)

    # deliberately poor initial guesses for the mixture parameters
    mu = torch.tensor([-3.0, -2.5])
    stdev = torch.tensor([0.2, 0.2])
    prior = torch.tensor([0.5, 0.5])

    converged = False
    i = 0

    while not converged:
        prev_mu = mu.clone()
        prev_stdev = stdev.clone()

        # E step: posterior p(H | x) by Bayes rule, computed in log space
        # so that points far from both components do not underflow to 0/0
        h = Normal(mu, stdev)
        llhood = h.log_prob(x)
        weighted_llhood = llhood + prior.log()
        log_sum_lhood = torch.logsumexp(weighted_llhood, dim=1, keepdim=True)
        log_posterior = weighted_llhood - log_sum_lhood
        posterior = torch.exp(log_posterior)

        if i % 3 == 0:
            plot(x, posterior, h)

        # M step: posterior-weighted maximum likelihood estimates
        mu = torch.sum(posterior * x, dim=0) / (torch.sum(posterior, dim=0) + eps)
        variance = torch.sum(posterior * (x - mu) ** 2, dim=0) / (torch.sum(posterior, dim=0) + eps)
        stdev = variance.sqrt()
        prior = posterior.mean(0)

        converged = torch.allclose(mu, prev_mu) and torch.allclose(stdev, prev_stdev)
        i += 1

    plot(x, posterior, h)
    print(i, mu, stdev, posterior.mean(0))
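For reference, the loop above implements the standard EM updates for a k-component Gaussian mixture; the code evaluates the E step in log space and adds eps to the denominators. Writing the responsibility of component j for point x_i as γ_ij, the updates in LaTeX form are:

\gamma_{ij} = \frac{\pi_j \, \mathcal{N}(x_i \mid \mu_j, \sigma_j^2)}{\sum_{l=1}^{k} \pi_l \, \mathcal{N}(x_i \mid \mu_l, \sigma_l^2)} \qquad \text{(E step)}

\mu_j = \frac{\sum_i \gamma_{ij} \, x_i}{\sum_i \gamma_{ij}}, \qquad
\sigma_j^2 = \frac{\sum_i \gamma_{ij} \, (x_i - \mu_j)^2}{\sum_i \gamma_{ij}}, \qquad
\pi_j = \frac{1}{n} \sum_i \gamma_{ij} \qquad \text{(M step)}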
Really great example to learn from. Had this idea to "just use PyTorch" and found your post. Thanks for sharing!