{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.distributions as torchdist"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def toy_poly():\n",
" \n",
" x = 5 * torch.rand(100, 1) \n",
" linear_op = -3 - 4*x + 1*x**2 \n",
" y = torchdist.Normal(linear_op, 1).sample()\n",
" return x, y\n",
"\n",
"x_train, y_train = toy_poly()"
]
},
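{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sanity check (an addition, not in the original gist): an ordinary least-squares fit on the design matrix $[1, x, x^2]$ should land near the true coefficients $(-3, -4, 1)$ used by `toy_poly`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical OLS baseline via the normal equations (X^T X) b = X^T y.\n",
"X = torch.cat([torch.ones_like(x_train), x_train, x_train ** 2], dim=1)\n",
"beta = torch.inverse(X.t() @ X) @ (X.t() @ y_train)\n",
"print(beta.squeeze())  # should be roughly (-3, -4, 1)"
]
},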
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def log_joint_prob(w0, w1, w2, x, y):\n",
" \n",
" prior_w0 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))\n",
" prior_w1 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))\n",
" prior_w2 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))\n",
"\n",
" linear = w0 + w1*x + w2*x**2\n",
" likelihood = torchdist.Normal(linear, torch.ones_like(linear))\n",
" \n",
" return (\n",
" prior_w0.log_prob(w0) +\n",
" prior_w1.log_prob(w1) +\n",
" prior_w2.log_prob(w2) +\n",
" likelihood.log_prob(y).sum()\n",
" )"
]
},
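{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick check (added for illustration): the log joint should be much higher at the true generating weights $(-3, -4, 1)$ than at an arbitrary point such as all zeros."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the log joint at the true weights and at zeros.\n",
"true_w = [torch.tensor(-3.), torch.tensor(-4.), torch.tensor(1.)]\n",
"zero_w = [torch.tensor(0.), torch.tensor(0.), torch.tensor(0.)]\n",
"print(log_joint_prob(*true_w, x_train, y_train))\n",
"print(log_joint_prob(*zero_w, x_train, y_train))"
]
},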
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"variational_params = {\n",
" \"w0_loc\": torch.nn.Parameter(torch.tensor(0.)),\n",
" \"w0_scale_log\": torch.nn.Parameter(torch.tensor(0.)),\n",
" \"w1_loc\": torch.nn.Parameter(torch.tensor(0.)),\n",
" \"w1_scale_log\": torch.nn.Parameter(torch.tensor(0.)),\n",
" \"w2_loc\": torch.nn.Parameter(torch.tensor(0.)),\n",
" \"w2_scale_log\": torch.nn.Parameter(torch.tensor(0.)),\n",
"}\n",
"\n",
"def variational_model(variational_params):\n",
" \"\"\"\n",
" Variational model q(w; eta)\n",
" arg: variational parameters \"eta\"\n",
" return: w ~ q(w; eta)\n",
" \"\"\"\n",
" w0_q = torchdist.Normal(\n",
" variational_params[\"w0_loc\"],\n",
" torch.exp(variational_params[\"w0_scale_log\"]),\n",
" )\n",
" \n",
" w1_q = torchdist.Normal(\n",
" variational_params[\"w1_loc\"],\n",
" torch.exp(variational_params[\"w1_scale_log\"]),\n",
" )\n",
" \n",
" w2_q = torchdist.Normal(\n",
" variational_params[\"w2_loc\"],\n",
" torch.exp(variational_params[\"w2_scale_log\"]),\n",
" )\n",
" \n",
" return w0_q, w1_q, w2_q"
]
},
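{
"cell_type": "markdown",
"metadata": {},
"source": [
"The point of `rsample()` (as opposed to `sample()`) is that a draw is a deterministic function of the parameters plus standard-normal noise, so it can be backpropagated through. A small added check (note: it leaves values in `.grad`, which the training loop's `zero_grad()` clears):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Gradients flow from a reparameterized draw back to the parameters.\n",
"w0_q, _, _ = variational_model(variational_params)\n",
"w0_q.rsample().backward()\n",
"print(variational_params[\"w0_loc\"].grad)        # d sample / d loc = 1\n",
"print(variational_params[\"w0_scale_log\"].grad)  # = eps * scale"
]
},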
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def kl_divergence(variational_params, x, y):\n",
" w0_q, w1_q, w2_q = variational_model(variational_params)\n",
" \n",
" w0_sample = w0_q.rsample()\n",
" w1_sample = w1_q.rsample() \n",
" w2_sample = w2_q.rsample()\n",
" \n",
" log_joint_prob_value = log_joint_prob(w0_sample, w1_sample, w2_sample, x, y)\n",
" log_variational_prob_value = (\n",
" w0_q.log_prob(w0_sample) +\n",
" w1_q.log_prob(w1_sample) +\n",
" w2_q.log_prob(w2_sample)\n",
" )\n",
" \n",
" return log_variational_prob_value - log_joint_prob_value"
]
},
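{
"cell_type": "markdown",
"metadata": {},
"source": [
"`kl_divergence` uses a single Monte Carlo sample, so its value is noisy; that noise is visible in the loss trace below. A variance-reduced variant (an added sketch, with a hypothetical `num_samples` argument) simply averages several independent estimates:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def kl_divergence_avg(variational_params, x, y, num_samples=10):\n",
"    # Average independent single-sample estimates of the same objective.\n",
"    estimates = [kl_divergence(variational_params, x, y)\n",
"                 for _ in range(num_samples)]\n",
"    return torch.stack(estimates).mean()"
]
},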
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2010.088\n",
"216.45891\n",
"151.53426\n",
"170.11266\n",
"169.32898\n",
"149.42986\n",
"216.23782\n",
"154.24156\n",
"156.65796\n",
"157.64983\n",
"155.47226\n",
"169.41145\n",
"160.99359\n",
"190.32835\n",
"200.18675\n",
"157.82788\n",
"169.57982\n",
"158.5978\n",
"160.85092\n",
"161.54916\n",
"158.56343\n",
"166.40855\n",
"155.73099\n",
"169.33304\n",
"154.00766\n",
"150.02861\n",
"162.02048\n",
"166.96443\n",
"155.15277\n",
"155.5115\n",
"155.15538\n"
]
}
],
"source": [
"optimizer = torch.optim.SGD(params=variational_params.values(), lr=1e-4)\n",
"\n",
"for i in range(9000):\n",
" optimizer.zero_grad()\n",
" loss_value =kl_divergence(variational_params, x_train, y_train)\n",
" loss_value.backward()\n",
" optimizer.step()\n",
" \n",
" if (i+1) % 300 == 0 or (i==0):\n",
" print(loss_value.detach().numpy())"
]
},
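{
"cell_type": "markdown",
"metadata": {},
"source": [
"To read the fitted approximate posterior, map the log-scales back through `exp`. This summary cell is an addition; the posterior means should sit close to the true weights $(-3, -4, 1)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Report loc and exp(scale_log) for each weight's Normal posterior.\n",
"for name in [\"w0\", \"w1\", \"w2\"]:\n",
"    loc = variational_params[name + \"_loc\"].detach()\n",
"    scale = torch.exp(variational_params[name + \"_scale_log\"]).detach()\n",
"    print(name, \"mean:\", loc.numpy(), \"std:\", scale.numpy())"
]
},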
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'w0_loc': Parameter containing:\n",
" tensor(-3.2885, requires_grad=True), 'w0_scale_log': Parameter containing:\n",
" tensor(-2.2080, requires_grad=True), 'w1_loc': Parameter containing:\n",
" tensor(-3.9548, requires_grad=True), 'w1_scale_log': Parameter containing:\n",
" tensor(-3.3054, requires_grad=True), 'w2_loc': Parameter containing:\n",
" tensor(1.0232, requires_grad=True), 'w2_scale_log': Parameter containing:\n",
" tensor(-4.7245, requires_grad=True)}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"variational_params"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}