Last active
May 6, 2020 16:54
-
-
Save dasayan05/aca3352cd00058511e8372912ff685d8 to your computer and use it in GitHub Desktop.
Example usage of Pyro for fitting a mixture of Gaussians (MoG)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pyro, torch, numpy as np | |
import pyro.distributions as dist | |
import pyro.optim as optim | |
import pyro.infer as infer | |
import matplotlib.pyplot as plt | |
plt.style.use('ggplot') | |
from scipy.stats import norm | |
plt.ioff() | |
def getdata(N, mean1=2.0, mean2=-1.0, std1=0.5, std2=0.5):
    """Sample a shuffled 1-D dataset from a two-component Gaussian mixture.

    Args:
        N: total number of points to return. The first component receives
            ``N // 2`` points and the second the remaining ``N - N // 2``, so
            exactly ``N`` points come back even when ``N`` is odd.
        mean1, std1: mean and standard deviation of the first component.
        mean2, std2: mean and standard deviation of the second component.

    Returns:
        A float32 torch tensor of shape ``(N,)`` with the two components'
        samples shuffled together.
    """
    n1 = N // 2
    # The second component takes the remainder; using ``N // 2`` here as well
    # dropped one point whenever N was odd.
    n2 = N - n1
    D1 = np.random.randn(n1) * std1 + mean1
    D2 = np.random.randn(n2) * std2 + mean2
    D = np.concatenate([D1, D2], 0)
    np.random.shuffle(D)  # in-place shuffle using NumPy's global RNG
    return torch.from_numpy(D.astype(np.float32))
@infer.config_enumerate(default='parallel')
def model(data):
    """Two-component Gaussian mixture model over the observations in ``data``.

    For every observation ``x_i``, an assignment ``c_i ~ Bernoulli(f)`` is
    drawn, and then ``x_i ~ Normal(M[c_i], S[c_i])`` is observed. The mixing
    probability ``f`` (constrained to [0, 1]), the component means ``M`` and
    the component standard deviations ``S`` (constrained positive) are learnable
    Pyro params. They are initialised to 0.5, [1.5, 3.0] and [0.5, 0.5].

    Args:
        data: tensor of observed values, one mixture draw per element along
            the ``"data"`` plate (presumably 1-D, as produced by ``getdata`` —
            TODO confirm for other callers).
    """
    f = pyro.param("f", torch.tensor([0.5]), constraint=dist.constraints.unit_interval)
    means = pyro.param("M", torch.tensor([1.5, 3.]))
    stds = pyro.param("S", torch.tensor([0.5, 0.5]), constraint=dist.constraints.positive)
    with pyro.plate("data", len(data)):
        c = pyro.sample("c", dist.Bernoulli(f))
        # Bernoulli samples are floating point; convert them to integer indices.
        # ``.long()`` keeps the tensor on its current device, whereas
        # ``.type(torch.LongTensor)`` always yields a CPU tensor.
        c = c.long()
        pyro.sample("x", dist.Normal(means[c], stds[c]), obs=data)
@infer.config_enumerate(default='parallel')
def guide(data):
    """Variational posterior over the per-point assignments ``c``.

    Each observation gets its own learnable probability ``pc[i]`` (constrained
    to [0, 1], randomly initialised) that its assignment ``c_i`` equals 1. The
    ``c`` site is marked for parallel enumeration.
    """
    n_points = len(data)
    probs = pyro.param("pc", torch.rand(n_points), constraint=dist.constraints.unit_interval)
    with pyro.plate("data", n_points):
        pyro.sample("c", dist.Bernoulli(probs))
# Synthetic data: 200 points from the two-component mixture in ``getdata``.
data = getdata(200)
pyro.clear_param_store()  # start from fresh parameter initialisations
# Adam with an empty options dict, i.e. PyTorch's default hyper-parameters.
# Named ``optimizer`` so the ``optim`` module alias is not shadowed.
optimizer = optim.Adam({})
svi = infer.SVI(model, guide, optimizer, infer.TraceEnum_ELBO())
# Three panels: loss curve, per-point posterior, fitted densities.
fig, ax = plt.subplots(1, 3, figsize=(15, 4))
losses = []  # one loss value (negative ELBO estimate) per SVI step
T = 10000  # total number of SVI steps
REFRESH_EVERY = 50  # redraw the figure every this many SVI steps


def _plot_loss(axis, losses, n_steps):
    """Draw the loss curve and mark/annotate its most recent value."""
    last = len(losses) - 1  # x-coordinate of the newest point in plot(losses)
    axis.plot(losses, color='m')
    axis.scatter(last, losses[-1], color='m')
    axis.annotate(f'{losses[-1]:.2f}', (last + 50, losses[-1] + 50))
    axis.set_xlabel("epochs")
    # svi.step returns the loss, i.e. the negative ELBO estimate.
    axis.set_ylabel("Loss (-ELBO)")
    axis.set_xlim([0, n_steps])
    axis.set_ylim([0, 2500])


def _plot_posterior(axis, data):
    """Scatter each data point against its guide probability ``pc`` of c == 1."""
    pc = pyro.param("pc").detach().numpy()
    axis.scatter(data.detach().numpy(), pc, c=pc)
    axis.set_ylim([-0.03, 1.03])
    axis.set_xlabel("Data axis")
    axis.set_ylabel(r"Posterior ($\lambda_i$)")


def _plot_densities(axis, data):
    """Plot both fitted Gaussian pdfs, their means, the coin bias ``f`` and the data."""
    mean1, mean2 = pyro.param("M").detach().numpy()
    std1, std2 = pyro.param("S").detach().numpy()
    coinbias = pyro.param("f").detach().item()
    pc = pyro.param("pc").detach().numpy()
    # Plain floats so numpy/matplotlib never receive 0-d tensors.
    xmin, xmax = data.min().item(), data.max().item()
    xs = np.linspace(xmin - 2, xmax + 2, 150)
    p1, = axis.plot(xs, norm.pdf(xs, mean1, std1), color='r')
    p2, = axis.plot(xs, norm.pdf(xs, mean2, std2), color='b')
    axis.axvline(mean1, linestyle='--', color='r')
    axis.axvline(mean2, linestyle='--', color='b')
    cb = axis.axhline(coinbias, linestyle='--', color='black')
    axis.set_xlim([xmin - 2, xmax + 2])
    axis.set_ylim([-0.02, 0.8])
    axis.legend([p1, p2, cb], ['Gaussian 1', 'Gaussian 2', 'Coin Bias'], loc=2)
    # Data points on the x-axis, coloured by their posterior probability.
    axis.scatter(data.numpy(), np.zeros_like(data.numpy()), marker='x', c=pc)
    axis.set_xlabel("Data axis")
    axis.set_ylabel("Model densities")


for t in range(T):
    losses.append(svi.step(data))
    if t % REFRESH_EVERY == 0:
        # Clear before drawing (not after) so the latest snapshot stays visible.
        for axis in ax:
            axis.cla()
        _plot_loss(ax[0], losses, T)
        _plot_posterior(ax[1], data)
        _plot_densities(ax[2], data)
        # To save animation frames, call here:
        #   plt.savefig(f"tmp/{t}.png", bbox_inches='tight', pad_inches=0)
        plt.pause(0.01)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment