Created
August 28, 2018 21:21
-
-
Save sadatnfs/f0ffa19b3de1bfd8b9ee93bce9786f7b to your computer and use it in GitHub Desktop.
Simple Bayesian Regression example using Pyro
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"import pyro\n", | |
"from pyro.distributions import Normal\n", | |
"from pyro.infer import SVI, Trace_ELBO\n", | |
"from pyro.optim import Adam\n", | |
"import pyro.distributions as dist\n", | |
"\n", | |
"# for CI testing\n", | |
"smoke_test = ('CI' in os.environ)\n", | |
"pyro.enable_validation(True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Bayesian Regression \n", | |
"Learning a function of the form:\n", | |
" $$y = wX + b + \\epsilon$$" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"N = 500 # size of toy data\n", | |
"p=2\n", | |
"\n", | |
"## Build a simple linear dataset\n", | |
"def build_linear_dataset(N, p=1, noise_std=0.05, w = 3, b = 1):\n", | |
" X = np.random.rand(N, p)\n", | |
" w = w * np.ones(p)\n", | |
" y = np.matmul(X, w) + np.repeat(b, N) + np.random.normal(0, noise_std, size=N)\n", | |
" y = y.reshape(N, 1)\n", | |
" X, y = torch.tensor(X).type(torch.Tensor), torch.tensor(y).type(torch.Tensor)\n", | |
" data = torch.cat((X, y), 1)\n", | |
" assert data.shape == (N, p + 1)\n", | |
" return data\n", | |
"\n", | |
"## Define our regression model module\n", | |
"class RegressionModel(nn.Module):\n", | |
" def __init__(self, p): \n", | |
" super(RegressionModel, self).__init__()\n", | |
" # p = number of features\n", | |
" # 1 = number of output dimensions\n", | |
" self.linear = nn.Linear(p, 1) # p in and one out\n", | |
"\n", | |
" def forward(self, x):\n", | |
" y_pred = self.linear(x)\n", | |
" return y_pred\n", | |
"\n", | |
"regression_model = RegressionModel(p)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loc = torch.zeros(1, p)\n", | |
"scale = torch.ones(1, p)\n", | |
"# define a unit normal prior\n", | |
"prior = Normal(loc, scale)\n", | |
"# overload the parameters in the regression module with samples from the prior\n", | |
"lifted_module = pyro.random_module(\"regression_module\", regression_model, prior)\n", | |
"# sample a regressor from the prior\n", | |
"sampled_reg_model = lifted_module()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def model(data):\n", | |
" # Create unit normal priors over the parameters\n", | |
" loc, scale = torch.zeros(1, p), 10 * torch.ones(1, p)\n", | |
" bias_loc, bias_scale = torch.zeros(1), 10 * torch.ones(1)\n", | |
" w_prior = Normal(loc, scale).independent(p)\n", | |
" b_prior = Normal(bias_loc, bias_scale).independent(1)\n", | |
" priors = {'linear.weight': w_prior, 'linear.bias': b_prior}\n", | |
" # lift module parameters to random variables sampled from the priors\n", | |
" lifted_module = pyro.random_module(\"module\", regression_model, priors)\n", | |
" # sample a regressor (which also samples w and b)\n", | |
" lifted_reg_model = lifted_module()\n", | |
" with pyro.iarange(\"map\", N):\n", | |
" x_data = data[:, :-1]\n", | |
" y_data = data[:, -1]\n", | |
"\n", | |
" # run the regressor forward conditioned on data\n", | |
" prediction_mean = lifted_reg_model(x_data).squeeze(-1)\n", | |
" # condition on the observed data\n", | |
" pyro.sample(\"obs\",\n", | |
" Normal(prediction_mean, 0.1 * torch.ones(data.size(0))),\n", | |
" obs=y_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"softplus = torch.nn.Softplus()\n", | |
"\n", | |
"def guide(data):\n", | |
" # define our variational parameters\n", | |
" w_loc = torch.randn(1, p)\n", | |
" # note that we initialize our scales to be pretty narrow\n", | |
" w_log_sig = torch.tensor(-3.0 * torch.ones(1, p) + 0.05 * torch.randn(1, 1))\n", | |
" b_loc = torch.randn(1)\n", | |
" b_log_sig = torch.tensor(-3.0 * torch.ones(1) + 0.05 * torch.randn(1))\n", | |
" # register learnable params in the param store\n", | |
" mw_param = pyro.param(\"guide_mean_weight\", w_loc)\n", | |
" sw_param = softplus(pyro.param(\"guide_log_scale_weight\", w_log_sig))\n", | |
" mb_param = pyro.param(\"guide_mean_bias\", b_loc)\n", | |
" sb_param = softplus(pyro.param(\"guide_log_scale_bias\", b_log_sig))\n", | |
" # guide distributions for w and b\n", | |
" w_dist = Normal(mw_param, sw_param).independent(p)\n", | |
" b_dist = Normal(mb_param, sb_param).independent(1)\n", | |
" dists = {'linear.weight': w_dist, 'linear.bias': b_dist}\n", | |
" # overload the parameters in the module with random samples\n", | |
" # from the guide distributions\n", | |
" lifted_module = pyro.random_module(\"module\", regression_model, dists)\n", | |
" # sample a regressor (which also samples w and b)\n", | |
" return lifted_module()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"optim = Adam({\"lr\": 0.05})\n", | |
"svi = SVI(model, guide, optim, loss=Trace_ELBO())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"num_iterations = 1500\n", | |
"def main():\n", | |
" pyro.clear_param_store()\n", | |
" data = build_linear_dataset(N, p=p, w = np.array([1., 0.5]), b = 4.)\n", | |
" for j in range(num_iterations):\n", | |
" # calculate the loss and take a gradient step\n", | |
" loss = svi.step(data)\n", | |
" if j % 100 == 0:\n", | |
" print(\"[iteration %04d] loss: %.4f\" % (j + 1, loss / float(N)))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[iteration 0001] loss: 1236.4488\n", | |
"[iteration 0101] loss: 37.6479\n", | |
"[iteration 0201] loss: 22.9069\n", | |
"[iteration 0301] loss: 10.0312\n", | |
"[iteration 0401] loss: 4.0211\n", | |
"[iteration 0501] loss: 0.7615\n", | |
"[iteration 0601] loss: -0.5694\n", | |
"[iteration 0701] loss: -1.0169\n", | |
"[iteration 0801] loss: -1.1444\n", | |
"[iteration 0901] loss: -1.1337\n", | |
"CPU times: user 3min 59s, sys: 927 ms, total: 4min\n", | |
"Wall time: 15.7 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time \n", | |
"\n", | |
"main()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"('guide_mean_weight', array([[1.0170887, 0.516914 ]], dtype=float32))\n", | |
"('guide_log_scale_weight', array([[-3.5383508, -3.70919 ]], dtype=float32))\n", | |
"('guide_mean_bias', array([3.9868238], dtype=float32))\n", | |
"('guide_log_scale_bias', array([-4.094978], dtype=float32))\n" | |
] | |
} | |
], | |
"source": [ | |
"for name in pyro.get_param_store().get_all_param_names():\n", | |
" print( (name, pyro.param(name).data.numpy()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 94, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "TypeError", | |
"evalue": "Can't instantiate abstract class TracePosterior with abstract methods _traces", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-94-9647bbbb5a52>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabstract_infer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracePosterior\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m: Can't instantiate abstract class TracePosterior with abstract methods _traces" | |
] | |
} | |
], | |
"source": [ | |
"pyro.infer.abstract_infer.TracePosterior(model, guide, 2)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment