Created February 17, 2019 06:15
Save fehiepsi/cf1fa0cba2bc30cc779d553eaf929a02 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.distributions as torchdist" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def toy_poly(n_points=100, noise_scale=1.0):\n", | |
"    \"\"\"Sample a toy quadratic regression dataset.\n", | |
"\n", | |
"    x ~ Uniform(0, 5) and y ~ Normal(-3 - 4*x + x**2, noise_scale).\n", | |
"\n", | |
"    Args:\n", | |
"        n_points: number of (x, y) pairs to draw (default 100, as before).\n", | |
"        noise_scale: std of the Gaussian observation noise (default 1).\n", | |
"\n", | |
"    Returns:\n", | |
"        (x, y): two float tensors of shape (n_points, 1).\n", | |
"    \"\"\"\n", | |
"    x = 5 * torch.rand(n_points, 1)\n", | |
"    # True mean is quadratic in x: w0=-3, w1=-4, w2=1 (what VI should recover).\n", | |
"    mean = -3 - 4*x + 1*x**2\n", | |
"    y = torchdist.Normal(mean, noise_scale).sample()\n", | |
"    return x, y\n", | |
"\n", | |
"\n", | |
"# Seed the global RNG so the data (and the later stochastic optimisation) are reproducible.\n", | |
"SEED = 0\n", | |
"torch.manual_seed(SEED)\n", | |
"x_train, y_train = toy_poly()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def log_joint_prob(w0, w1, w2, x, y, prior_scale=10.0, noise_scale=1.0):\n", | |
"    \"\"\"Unnormalised log posterior log p(w0, w1, w2, y | x) of the quadratic model.\n", | |
"\n", | |
"    Model: each w_i ~ Normal(0, prior_scale) independently, and\n", | |
"    y ~ Normal(w0 + w1*x + w2*x**2, noise_scale) elementwise.\n", | |
"\n", | |
"    Args:\n", | |
"        w0, w1, w2: weight tensors, presumably scalars (as drawn by the\n", | |
"            variational guide); gradients flow through them.\n", | |
"        x, y: data tensors; y must broadcast against the predicted mean.\n", | |
"        prior_scale: std of the shared Normal prior (default 10, as before).\n", | |
"        noise_scale: observation-noise std (default 1, as before).\n", | |
"\n", | |
"    Returns:\n", | |
"        The three prior log-densities plus the log-likelihood summed over\n", | |
"        all observations.\n", | |
"    \"\"\"\n", | |
"    # The three priors were identical copies; one distribution serves all weights.\n", | |
"    prior = torchdist.Normal(torch.tensor(0.), torch.tensor(float(prior_scale)))\n", | |
"\n", | |
"    mean = w0 + w1*x + w2*x**2\n", | |
"    likelihood = torchdist.Normal(mean, noise_scale * torch.ones_like(mean))\n", | |
"\n", | |
"    return (\n", | |
"        prior.log_prob(w0) +\n", | |
"        prior.log_prob(w1) +\n", | |
"        prior.log_prob(w2) +\n", | |
"        likelihood.log_prob(y).sum()\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Variational parameters eta: a mean and a log-std per weight, all starting\n", | |
"# at 0, so q initially puts Normal(0, 1) on every weight.\n", | |
"variational_params = {\n", | |
"    key: torch.nn.Parameter(torch.tensor(0.))\n", | |
"    for key in (\"w0_loc\", \"w0_scale_log\",\n", | |
"                \"w1_loc\", \"w1_scale_log\",\n", | |
"                \"w2_loc\", \"w2_scale_log\")\n", | |
"}\n", | |
"\n", | |
"\n", | |
"def variational_model(variational_params):\n", | |
"    \"\"\"Build the mean-field Gaussian guide q(w; eta).\n", | |
"\n", | |
"    Each weight gets an independent Normal whose scale is the exponential\n", | |
"    of its log-scale entry, so the std stays positive while the optimised\n", | |
"    parameter is unconstrained.\n", | |
"\n", | |
"    Args:\n", | |
"        variational_params: mapping with a loc and a scale_log entry per weight.\n", | |
"\n", | |
"    Returns:\n", | |
"        Tuple (w0_q, w1_q, w2_q) of Normal distributions.\n", | |
"    \"\"\"\n", | |
"    key_pairs = (\n", | |
"        (\"w0_loc\", \"w0_scale_log\"),\n", | |
"        (\"w1_loc\", \"w1_scale_log\"),\n", | |
"        (\"w2_loc\", \"w2_scale_log\"),\n", | |
"    )\n", | |
"    return tuple(\n", | |
"        torchdist.Normal(variational_params[loc_key],\n", | |
"                         variational_params[scale_key].exp())\n", | |
"        for loc_key, scale_key in key_pairs\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def kl_divergence(variational_params, x, y, num_samples=1):\n", | |
"    \"\"\"Monte Carlo estimate of the variational objective (negative ELBO).\n", | |
"\n", | |
"    Despite the name, this does not return KL(q || posterior) itself. It returns\n", | |
"    E_q[log q(w) - log p(w, y | x)] = KL(q(w) || p(w | x, y)) - log p(y | x).\n", | |
"    The two differ only by the log evidence, which does not depend on the\n", | |
"    variational parameters, so minimising either has the same optimum.\n", | |
"\n", | |
"    Weights are drawn with rsample() (reparameterisation trick) so gradients\n", | |
"    flow back into variational_params.\n", | |
"\n", | |
"    Args:\n", | |
"        variational_params: dict of variational parameters (see variational_model).\n", | |
"        x, y: training inputs and targets, forwarded to log_joint_prob.\n", | |
"        num_samples: number of Monte Carlo draws to average. The default of 1\n", | |
"            reproduces the original single-sample estimator; larger values\n", | |
"            reduce the variance of the loss and its gradient.\n", | |
"\n", | |
"    Returns:\n", | |
"        A tensor estimating the negative ELBO.\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if num_samples < 1.\n", | |
"    \"\"\"\n", | |
"    if num_samples < 1:\n", | |
"        raise ValueError(\"num_samples must be >= 1, got %r\" % (num_samples,))\n", | |
"    w0_q, w1_q, w2_q = variational_model(variational_params)\n", | |
"\n", | |
"    total = 0.\n", | |
"    for _ in range(num_samples):\n", | |
"        w0_sample = w0_q.rsample()\n", | |
"        w1_sample = w1_q.rsample()\n", | |
"        w2_sample = w2_q.rsample()\n", | |
"\n", | |
"        log_joint_prob_value = log_joint_prob(w0_sample, w1_sample, w2_sample, x, y)\n", | |
"        log_variational_prob_value = (\n", | |
"            w0_q.log_prob(w0_sample) +\n", | |
"            w1_q.log_prob(w1_sample) +\n", | |
"            w2_q.log_prob(w2_sample)\n", | |
"        )\n", | |
"        total = total + (log_variational_prob_value - log_joint_prob_value)\n", | |
"\n", | |
"    return total / num_samples" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2010.088\n", | |
"216.45891\n", | |
"151.53426\n", | |
"170.11266\n", | |
"169.32898\n", | |
"149.42986\n", | |
"216.23782\n", | |
"154.24156\n", | |
"156.65796\n", | |
"157.64983\n", | |
"155.47226\n", | |
"169.41145\n", | |
"160.99359\n", | |
"190.32835\n", | |
"200.18675\n", | |
"157.82788\n", | |
"169.57982\n", | |
"158.5978\n", | |
"160.85092\n", | |
"161.54916\n", | |
"158.56343\n", | |
"166.40855\n", | |
"155.73099\n", | |
"169.33304\n", | |
"154.00766\n", | |
"150.02861\n", | |
"162.02048\n", | |
"166.96443\n", | |
"155.15277\n", | |
"155.5115\n", | |
"155.15538\n" | |
] | |
} | |
], | |
"source": [ | |
"# Optimisation settings. Re-running this cell continues training from the\n", | |
"# current variational_params; re-run the earlier cells to start from scratch.\n", | |
"N_STEPS = 9000\n", | |
"LOG_EVERY = 300\n", | |
"LEARNING_RATE = 1e-4\n", | |
"\n", | |
"optimizer = torch.optim.SGD(params=list(variational_params.values()), lr=LEARNING_RATE)\n", | |
"\n", | |
"window_loss, window_steps = 0., 0\n", | |
"for i in range(N_STEPS):\n", | |
"    optimizer.zero_grad()\n", | |
"    loss_value = kl_divergence(variational_params, x_train, y_train)\n", | |
"    loss_value.backward()\n", | |
"    optimizer.step()\n", | |
"\n", | |
"    # Each loss is a noisy single-sample estimate, so report the mean over the\n", | |
"    # logging window rather than whichever sample landed on this step.\n", | |
"    window_loss += loss_value.item()\n", | |
"    window_steps += 1\n", | |
"    if i == 0 or (i + 1) % LOG_EVERY == 0:\n", | |
"        print(\"step %5d  mean loss %.3f\" % (i + 1, window_loss / window_steps))\n", | |
"        window_loss, window_steps = 0., 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'w0_loc': Parameter containing:\n", | |
" tensor(-3.2885, requires_grad=True), 'w0_scale_log': Parameter containing:\n", | |
" tensor(-2.2080, requires_grad=True), 'w1_loc': Parameter containing:\n", | |
" tensor(-3.9548, requires_grad=True), 'w1_scale_log': Parameter containing:\n", | |
" tensor(-3.3054, requires_grad=True), 'w2_loc': Parameter containing:\n", | |
" tensor(1.0232, requires_grad=True), 'w2_scale_log': Parameter containing:\n", | |
" tensor(-4.7245, requires_grad=True)}" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Posterior summary: mean and std (exp of the log-scale) of each weight under q.\n", | |
"# The data were generated with w0=-3, w1=-4, w2=1.\n", | |
"{\n", | |
"    name: {\n", | |
"        \"mean\": variational_params[name + \"_loc\"].item(),\n", | |
"        \"std\": variational_params[name + \"_scale_log\"].exp().item(),\n", | |
"    }\n", | |
"    for name in (\"w0\", \"w1\", \"w2\")\n", | |
"}" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.