Last active
February 24, 2020 13:14
-
-
Save AshNguyen/b883455f51a60f3a94bff69069eec78c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Installing relevant packages, if you can't still import in the next cell\n", | |
"#please restart the kernel\n", | |
"#Adapted from Pollock, J. (2019).\n", | |
"%%bash \n", | |
"\n", | |
"pip install pyro-ppl\n", | |
"pip install pystan\n", | |
"pip install torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Importing packages\n", | |
"import numpy as np\n", | |
"import scipy.stats as st\n", | |
"import matplotlib.pyplot as plt\n", | |
"import pystan\n", | |
"import torch\n", | |
"import pyro\n", | |
"import pyro.distributions as dist\n", | |
"import pyro.contrib.autoguide as autoguide" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Generate the data, set seeds for reproducability\n", | |
"pyro.set_rng_seed(25)\n", | |
"\n", | |
"#Number of data point\n", | |
"N = 2500\n", | |
"#Number of feature in the linear regression model \n", | |
"#(y = bx + a + noise, b is the coefficient vector of length P)\n", | |
"P = 8\n", | |
"\n", | |
"#Sample the true values of the parameters\n", | |
"alpha_true = dist.Normal(42.0, 10.0).sample()\n", | |
"beta_true = dist.Normal(torch.zeros(P), 10.0).sample()\n", | |
"sigma_true = dist.Exponential(1.0).sample()\n", | |
"\n", | |
"#Data generation: first the noise is sampled from a normal, then the linear\n", | |
"#regression y is constructed from the coefficients\n", | |
"eps = dist.Normal(0.0, sigma_true).sample([N])\n", | |
"x = torch.randn(N, P)\n", | |
"y = alpha_true + x @ beta_true + eps" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"TRUE ALPHA:\n", | |
" 45.366905\n", | |
"TRUE BETAs:\n", | |
" 1.288094\n", | |
" 2.3446236\n", | |
" 2.3033304\n", | |
" -11.228563\n", | |
" -1.8632829\n", | |
" 22.082014\n", | |
" -6.3799706\n", | |
" 4.6165724\n", | |
"TRUE SIGMA:\n", | |
" 0.17085178\n" | |
] | |
} | |
], | |
"source": [ | |
"#True parameter values\n", | |
"print('TRUE ALPHA:\\n ', alpha_true.numpy())\n", | |
"print('TRUE BETAs:')\n", | |
"for _ in beta_true.numpy():\n", | |
" print(\" \"+str(_))\n", | |
"print('TRUE SIGMA:\\n ', sigma_true.numpy())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_9668ba3a03de4e2c24a1b04fa6c99bd7 NOW.\n", | |
"/Users/ash/anaconda3/lib/python3.6/site-packages/Cython/Compiler/Main.py:367: FutureWarning: Cython directive 'language_level' not set, using 2 for now (Py2). This will change in a later release! File: /var/folders/kg/zv3bhpmx19s5f8hsk7h4_wnc0000gn/T/tmptv2dgwc7/stanfit4anon_model_9668ba3a03de4e2c24a1b04fa6c99bd7_4726917534686056046.pyx\n", | |
" tree = Parsing.p_module(s, pxd, full_module_name)\n" | |
] | |
} | |
], | |
"source": [ | |
"#MCMC: NUTS\n", | |
"\n", | |
"model = \"\"\"\n", | |
"data {\n", | |
" int<lower = 0> N;\n", | |
" int<lower = 0> P;\n", | |
" matrix[N, P] x;\n", | |
" vector[N] y;\n", | |
"}\n", | |
"\n", | |
"parameters {\n", | |
" real alpha;\n", | |
" vector[P] beta;\n", | |
" real<lower = 0.0> sigma;\n", | |
"}\n", | |
"\n", | |
"model {\n", | |
" alpha ~ normal(0.0, 100.0);\n", | |
" beta ~ normal(0.0, 10.0);\n", | |
" sigma ~ normal(0.0, 10.0);\n", | |
" y ~ normal(alpha + x * beta, sigma);\n", | |
"}\n", | |
"\"\"\"\n", | |
"stan_model = pystan.StanModel(model_code=model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = {\n", | |
" 'N': N,\n", | |
" 'P': P, \n", | |
" 'x': x.numpy(), \n", | |
" 'y': y.numpy()\n", | |
"}\n", | |
"\n", | |
"stan_results = stan_model.sampling(data=data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Inference for Stan model: anon_model_9668ba3a03de4e2c24a1b04fa6c99bd7.\n", | |
"4 chains, each with iter=2000; warmup=1000; thin=1; \n", | |
"post-warmup draws per chain=1000, total post-warmup draws=4000.\n", | |
"\n", | |
" mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat\n", | |
"alpha 45.37 3.7e-5 3.3e-3 45.36 45.37 45.37 45.37 45.37 7598 1.0\n", | |
"beta[1] 1.29 4.0e-5 3.4e-3 1.29 1.29 1.29 1.29 1.3 7477 1.0\n", | |
"beta[2] 2.35 4.0e-5 3.3e-3 2.34 2.35 2.35 2.35 2.36 6825 1.0\n", | |
"beta[3] 2.3 3.8e-5 3.4e-3 2.3 2.3 2.3 2.31 2.31 8112 1.0\n", | |
"beta[4] -11.23 4.0e-5 3.3e-3 -11.24 -11.23 -11.23 -11.23 -11.22 6861 1.0\n", | |
"beta[5] -1.86 4.1e-5 3.3e-3 -1.87 -1.86 -1.86 -1.86 -1.85 6481 1.0\n", | |
"beta[6] 22.09 4.1e-5 3.3e-3 22.08 22.09 22.09 22.09 22.1 6622 1.0\n", | |
"beta[7] -6.38 3.9e-5 3.3e-3 -6.39 -6.38 -6.38 -6.38 -6.38 6902 1.0\n", | |
"beta[8] 4.62 4.0e-5 3.4e-3 4.61 4.62 4.62 4.62 4.63 7132 1.0\n", | |
"sigma 0.17 3.1e-5 2.3e-3 0.16 0.17 0.17 0.17 0.17 5595 1.0\n", | |
"lp__ 3206.0 0.05 2.21 3200.9 3204.7 3206.3 3207.6 3209.3 1763 1.0\n", | |
"\n", | |
"Samples were drawn using NUTS at Mon Feb 24 19:15:14 2020.\n", | |
"For each parameter, n_eff is a crude measure of effective sample size,\n", | |
"and Rhat is the potential scale reduction factor on split chains (at \n", | |
"convergence, Rhat=1).\n" | |
] | |
} | |
], | |
"source": [ | |
"print(stan_results)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"step: 0, ELBO loss: -490289.81\n", | |
"step: 5000, ELBO loss: -28303.55\n", | |
"step: 10000, ELBO loss: -21660.62\n", | |
"step: 15000, ELBO loss: 1676.37\n", | |
"step: 20000, ELBO loss: 1663.49\n", | |
"step: 25000, ELBO loss: 1665.18\n" | |
] | |
} | |
], | |
"source": [ | |
"#VI: ELBO maximization with SVI\n", | |
"\n", | |
"#Optimizer parameters (here we use ADAM)\n", | |
"LEARNING_RATE = 1e-2\n", | |
"NUM_STEPS = 30000 #Number of steps for the optimizer\n", | |
"NUM_SAMPLES = 3000 #Number of sample to generate\n", | |
"\n", | |
"#Define the model in pyro\n", | |
"def model(x, y):\n", | |
" alpha = pyro.sample(\"alpha\", dist.Normal(0.0, 100.0))\n", | |
" beta = pyro.sample(\"beta\", dist.Normal(torch.zeros(P), 10.0))\n", | |
" sigma = pyro.sample(\"sigma\", dist.HalfNormal(10.0)) #To make sigma positive\n", | |
" mu = alpha + x @ beta\n", | |
" return pyro.sample(\"y\", dist.Normal(mu, sigma), obs=y)\n", | |
"\n", | |
"#A guide is the variational function for the model, this can be flexibly defined\n", | |
"#in pyro, but we use an automatic guide here to perform mean field Gaussian in\n", | |
"#the latent space\n", | |
"guide = autoguide.AutoDiagonalNormal(model)\n", | |
"optimiser = pyro.optim.Adam({\"lr\": LEARNING_RATE})\n", | |
"loss = pyro.infer.JitTraceGraph_ELBO() #monitor the loss\n", | |
"svi = pyro.infer.SVI(model, guide, optimiser, loss, num_samples=NUM_SAMPLES)\n", | |
"\n", | |
"losses = np.empty(NUM_STEPS)\n", | |
"\n", | |
"pyro.clear_param_store()\n", | |
"\n", | |
"#Print out the loss to make sure we are improving\n", | |
"for step in range(NUM_STEPS):\n", | |
" losses[step] = svi.step(x, y)\n", | |
" if step % 5000 == 0:\n", | |
" print(f\"step: {step:>5}, ELBO loss: {losses[step]:.2f}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"plt.figure(figsize=(12,8))\n", | |
"plt.plot(range(losses.shape[0]), losses)\n", | |
"plt.title('Loss')\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Extract sampling results\n", | |
"posterior = svi.run(x, y)\n", | |
"support = posterior.marginal([\"alpha\", \"beta\", \"sigma\"]).support()\n", | |
"\n", | |
"data_dict = {k: np.expand_dims(v.detach().numpy(), 0) for k, v in support.items()}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"alpha\n", | |
"Mean: 45.362644; std: 0.0031976588\n", | |
"beta\n", | |
"Mean: 1.2909813; std: 0.003279712\n", | |
"Mean: 2.3464794; std: 0.003473224\n", | |
"Mean: 2.3078; std: 0.0029695085\n", | |
"Mean: -11.229632; std: 0.0030872645\n", | |
"Mean: -1.8544587; std: 0.0035904588\n", | |
"Mean: 22.08873; std: 0.0030570698\n", | |
"Mean: -6.3798227; std: 0.0034353728\n", | |
"Mean: 4.6170278; std: 0.003123097\n", | |
"sigma\n", | |
"Mean: 0.16929966; std: 0.0023091426\n" | |
] | |
} | |
], | |
"source": [ | |
"for k in data_dict.keys():\n", | |
" d = data_dict[k][0]\n", | |
" print(k)\n", | |
" if k == 'beta':\n", | |
" for _ in range(P): \n", | |
" print(\"Mean: \"+str(np.mean(d[:,_]))+\"; std: \"+str(np.std(d[:,_])))\n", | |
" else: \n", | |
" print(\"Mean: \"+str(np.mean(d))+\"; std: \"+str(np.std(d)))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Task\n", | |
"\n", | |
    "1. Vary NUM_SAMPLES, NUM_STEPS and LEARNING_RATE to see how they affect the loss and the results\n", | |
    "2. Plot the distributions of the sampled parameters against the true values\n", | |
    "3. Compare MCMC (NUTS) and VI on both run time and the accuracy of their posterior estimates" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment