Skip to content

Instantly share code, notes, and snippets.

@dasayan05
Last active May 6, 2020 16:54
Show Gist options
  • Save dasayan05/aca3352cd00058511e8372912ff685d8 to your computer and use it in GitHub Desktop.
Save dasayan05/aca3352cd00058511e8372912ff685d8 to your computer and use it in GitHub Desktop.
Example usage of Pyro to fit a mixture of Gaussians (MoG)
"""Fit a two-component 1-D Gaussian mixture to synthetic data with Pyro SVI,
redrawing loss, per-point posteriors and fitted densities during training."""
import pyro, torch, numpy as np
import pyro.distributions as dist
import pyro.optim as optim
import pyro.infer as infer
import matplotlib.pyplot as plt
plt.style.use('ggplot')  # global matplotlib style for every plot below
from scipy.stats import norm
plt.ioff()  # interactive mode off; the training loop refreshes the figure via plt.pause
def getdata(N, mean1=2.0, mean2=-1.0, std1=0.5, std2=0.5):
    """Sample a shuffled 1-D dataset from two Gaussian clusters.

    The first ``N // 2`` points are drawn from Normal(mean1, std1) and the
    remaining ``N - N // 2`` from Normal(mean2, std2). The combined sample is
    shuffled in place with NumPy's global RNG.

    Args:
        N: Total number of points to return.
        mean1, std1: Mean and standard deviation of the first cluster.
        mean2, std2: Mean and standard deviation of the second cluster.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(N,)``.
    """
    n1 = N // 2
    # Give the leftover point of an odd N to cluster 2 so exactly N points
    # are returned (previously both halves used N // 2 and one was dropped).
    n2 = N - n1
    D1 = np.random.randn(n1) * std1 + mean1
    D2 = np.random.randn(n2) * std2 + mean2
    D = np.concatenate([D1, D2], 0)
    np.random.shuffle(D)
    return torch.from_numpy(D.astype(np.float32))
@infer.config_enumerate(default='parallel')
def model(data):
    """Two-component 1-D Gaussian mixture with a Bernoulli label per point.

    Learnable parameters (kept in Pyro's param store):
        "f": probability that a point's label is 1, i.e. that it uses
             index 1 of "M"/"S" (constrained to [0, 1]).
        "M": the two component means.
        "S": the two component standard deviations (constrained positive).

    Args:
        data: observed 1-D tensor of points; its length sets the plate size.
    """
    f = pyro.param("f", torch.tensor([0.5]), constraint=dist.constraints.unit_interval)
    means = pyro.param("M", torch.tensor([1.5, 3.]))
    stds = pyro.param("S", torch.tensor([0.5, 0.5]), constraint=dist.constraints.positive)
    with pyro.plate("data", len(data)):
        # Cluster label c in {0, 1}; config_enumerate marks this site for
        # parallel enumeration.
        c = pyro.sample("c", dist.Bernoulli(f))
        # Bernoulli samples are floats; cast to integer for indexing.
        # `.long()` keeps c on its current device, unlike
        # `.type(torch.LongTensor)`, which always produces a CPU tensor.
        c = c.long()
        pyro.sample("x", dist.Normal(means[c], stds[c]), obs=data)
@infer.config_enumerate(default='parallel')
def guide(data):
    """Variational guide: an independent Bernoulli over each point's label.

    Learns one probability per data point, stored as the "pc" param and
    constrained to [0, 1], and samples the "c" site from it inside a plate
    named "data" (the same plate name the model uses).
    """
    num_points = len(data)
    init_probs = torch.rand(num_points)  # only used the first time "pc" is created
    point_probs = pyro.param("pc", init_probs, constraint=dist.constraints.unit_interval)
    with pyro.plate("data", num_points):
        pyro.sample("c", dist.Bernoulli(point_probs))
data = getdata(200)
pyro.clear_param_store()
# Named `optimizer` so it does not shadow the `optim` (pyro.optim) module
# alias imported at the top of the file. lr=1e-3 is torch.optim.Adam's
# default, i.e. exactly what an empty options dict gave.
optimizer = optim.Adam({"lr": 1e-3})
# TraceEnum_ELBO handles the enumerated discrete "c" site; model and guide
# each nest a single plate ("data"), so max_plate_nesting is 1.
svi = infer.SVI(model, guide, optimizer, infer.TraceEnum_ELBO(max_plate_nesting=1))
fig, ax = plt.subplots(1, 3, figsize=(15, 4))
losses = []  # loss returned by svi.step, one entry per step
T = 10000  # number of SVI steps
def _draw_progress(ax, data, losses, T):
    """Draw the three diagnostic panels from the current param store.

    ax[0]: loss curve so far; ax[1]: per-point guide probability "pc";
    ax[2]: both fitted component densities, their means, the coin bias "f",
    and the data colored by "pc".
    """
    # Panel 0: training loss (svi.step returns the loss, i.e. negative ELBO).
    last_step = len(losses) - 1  # x of the newest point on the plotted curve
    ax[0].plot(losses, color='m')
    ax[0].scatter(last_step, losses[-1], color='m')
    ax[0].annotate(f'{losses[-1]:.2f}', (last_step + 50, losses[-1] + 50))
    ax[0].set_xlabel("epochs")
    ax[0].set_ylabel("Loss (-ELBO)")
    ax[0].set_xlim([0, T])
    ax[0].set_ylim([0, 2500])

    # Panel 1: guide probability of label 1 for each data point.
    data_np = data.detach().numpy()
    pc_np = pyro.param("pc").detach().numpy()
    ax[1].scatter(data_np, pc_np, c=pc_np)
    ax[1].set_ylim([-0.03, 1.03])
    ax[1].set_xlabel("Data axis")
    ax[1].set_ylabel(r"Posterior ($\lambda_i$)")

    # Panel 2: fitted component densities over the data range (+/- 2).
    mean1, mean2 = pyro.param("M").detach().numpy()
    std1, std2 = pyro.param("S").detach().numpy()
    coinbias = pyro.param("f").detach().item()
    xmin, xmax = float(data_np.min()), float(data_np.max())
    xs = np.linspace(xmin - 2, xmax + 2, 150)
    p1, = ax[2].plot(xs, norm.pdf(xs, mean1, std1), color='r')
    p2, = ax[2].plot(xs, norm.pdf(xs, mean2, std2), color='b')
    ax[2].axvline(mean1, linestyle='--', color='r')
    ax[2].axvline(mean2, linestyle='--', color='b')
    cb = ax[2].axhline(coinbias, linestyle='--', color='black')
    ax[2].set_xlim([xmin - 2, xmax + 2])
    ax[2].set_ylim([-0.02, 0.8])
    ax[2].legend([p1, p2, cb], ['Gaussian 1', 'Gaussian 2', 'Coin Bias'], loc=2)
    ax[2].scatter(data_np, np.zeros_like(data_np), marker='x', c=pc_np)
    ax[2].set_xlabel("Data axis")
    ax[2].set_ylabel("Model densities")


FRAME_DIR = None  # set to a directory (e.g. "tmp") to save each redrawn frame as PNG

for t in range(T):
    losses.append(svi.step(data))
    if t % 50 == 0:
        _draw_progress(ax, data, losses, T)
        if FRAME_DIR is not None:
            # pad_inches=0 removes the padding bbox_inches='tight' would add.
            fig.savefig(f"{FRAME_DIR}/{t}.png", bbox_inches='tight', pad_inches=0)
        plt.pause(0.01)
        for panel in ax:
            panel.cla()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment