```python
import torch
import torch.autograd as autograd
import torch.distributions as torch_dist
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
```
```python
data = dist.Normal(3, 0.5).sample(torch.Size([1000]))

def model():
    z = pyro.sample("z", dist.Gamma(1, 1))
    pyro.sample("obs", dist.Normal(3, z), obs=data)
```
```python
tmp_trace = poutine.trace(model).get_trace()
```
```python
def fn(zu):
    z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)
    tmp_trace.nodes["z"]["value"] = z
    trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))
    trace_poutine()
    trace = trace_poutine.trace
    return trace.log_prob_sum()
```
```python
zu = torch.tensor(0., requires_grad=True)
fn_jit = torch.jit.trace(fn, (zu,))
```
```python
autograd.grad(fn(zu), (zu,))[0]
```

Output: `tensor(-730.8032)`
```python
autograd.grad(fn_jit(zu), (zu,))[0]
```

Output: `tensor(-1729.8032)`
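A quick check that is not in the original notebook (editor's addition): compare the forward values of the eager and jit-traced functions first, to see whether the disagreement already shows up in `log_prob_sum` itself or only in the gradient.

```python
# Extra sanity check (editor's addition, not part of the gist): do fn and
# fn_jit at least agree on the forward value at zu = 0?
print(fn(zu))      # eager log_prob_sum
print(fn_jit(zu))  # jit-traced log_prob_sum
```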
### We write the same function out by hand to debug: `fn1` should be the same as `fn`
```python
def fn1(zu):
    z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)
    tmp_trace.nodes["z"]["value"] = z
    trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))
    trace_poutine()
    trace = trace_poutine.trace
    probs = []
    for name, site in trace.nodes.items():
        if site["type"] == "sample":
            log_p = site["fn"].log_prob(site["value"]).sum()
            probs.append(log_p)
    return probs[0] + probs[1]

zu = torch.tensor(0., requires_grad=True)
fn1_jit = torch.jit.trace(fn1, (zu,))
autograd.grad(fn1_jit(zu), (zu,))[0]
```

Output: `tensor(-1729.8032)`
Yes, it is the same, and `jit` gives the wrong result.
### How about not using trace?
```python
def fn2(zu):
    z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)
    p1 = dist.Gamma(1, 1).log_prob(z).sum()
    p2 = dist.Normal(3, z).log_prob(data).sum()
    return p1 + p2

zu = torch.tensor(0., requires_grad=True)
fn2_jit = torch.jit.trace(fn2, (zu,))
autograd.grad(fn2_jit(zu), (zu,))[0]
```

Output: `tensor(-730.8032)`
`jit` gives the right result now.
### Let's investigate. What is the `log_prob` at `z`?
```python
def fn_at_z(zu):
    z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)
    tmp_trace.nodes["z"]["value"] = z
    trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))
    trace_poutine()
    trace = trace_poutine.trace
    probs = []
    for name, site in trace.nodes.items():
        if site["type"] == "sample":
            log_p = site["fn"].log_prob(site["value"]).sum()
            probs.append(log_p)
    return probs[0]

zu = torch.tensor(0., requires_grad=True)
fn_at_z_jit = torch.jit.trace(fn_at_z, (zu,))
```
**nojit**

```python
print(autograd.grad(fn_at_z(zu), (zu,))[0])
```

Output: `tensor(-1.)`
**jit**

```python
autograd.grad(fn_at_z_jit(zu), (zu,))[0]
```

Output: `tensor(-1.)`
Things look fine up to this point, but...
```python
def fn1_at_z(zu):
    z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)
    tmp_trace.nodes["z"]["value"] = z
    trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))
    trace_poutine()
    trace = trace_poutine.trace
    probs = []
    for name, site in trace.nodes.items():
        if site["type"] == "sample":
            log_p = site["fn"].log_prob(site["value"]).sum()
            probs.append(log_p)
    return probs[0] + 0 * probs[1]  # !!! should give the same grad as fn_at_z!

zu = torch.tensor(0., requires_grad=True)
fn1_at_z = torch.jit.trace(fn1_at_z, (zu,))
autograd.grad(fn1_at_z(zu), (zu,))[0]
```

Output: `tensor(-1000.)`
Oops, the grad is off by 1000 (the size of `data`)!
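One way to narrow this down further (a sketch under the editor's own assumptions, not part of the gist): drop Pyro entirely and route `z` through an external mutable dict, mimicking what `tmp_trace.nodes["z"]["value"] = z` does, to see whether `torch.jit.trace` alone mishandles values that flow through Python-side state. The names `state` and `fn_dict` are hypothetical.

```python
# Hypothetical minimal-repro sketch (editor's assumption): mimic the
# mutable-state pattern of fn1_at_z without any Pyro machinery.
state = {}

def fn_dict(zu):
    z = zu.exp()                     # biject_to(positive) is the exp transform here
    state["z"] = z                   # value flows through external mutable state
    p1 = torch_dist.Gamma(1., 1.).log_prob(state["z"]).sum()
    p2 = torch_dist.Normal(3., state["z"]).log_prob(data).sum()
    return p1 + 0 * p2               # same "+ 0 * probs[1]" trick as above

zu = torch.tensor(0., requires_grad=True)
fn_dict_jit = torch.jit.trace(fn_dict, (zu,))
print(autograd.grad(fn_dict(zu), (zu,))[0])      # eager gradient
print(autograd.grad(fn_dict_jit(zu), (zu,))[0])  # jit gradient -- does the gap survive?
```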
z = 1 indeed (after transform). You can remove the transform code and set z = 1, but precision is not the point I want to make here. On the probability scale, e^-1 is definitely different from e^-1000. If the number of data points is 5000, then the grad will be 5000x off!
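A sketch of that scaling check (editor's addition, not run in the gist; it assumes `model`, `fn1`, and the other notebook globals are still in scope): regenerate `data` with 5000 points, rebuild the trace, and re-trace `fn1`. If the hypothesis is right, the gap between the eager and jit gradients should grow roughly with the dataset size.

```python
# Sketch only (editor's addition): does the spurious gradient term scale with
# the number of data points?
data = dist.Normal(3, 0.5).sample(torch.Size([5000]))  # model() reads this global
tmp_trace = poutine.trace(model).get_trace()           # fresh trace for the new data

zu = torch.tensor(0., requires_grad=True)
fn1_jit_5000 = torch.jit.trace(fn1, (zu,))
print(autograd.grad(fn1(zu), (zu,))[0])           # eager gradient
print(autograd.grad(fn1_jit_5000(zu), (zu,))[0])  # jit gradient -- compare the gap
```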
> z = 1 indeed (after transform)

Sorry, I missed the transform code, and this indeed is surprising. Does the jit throw any warnings at all?
There is no warning at all.
Great to know, @fehiepsi! 👍 on the cool detective work. Moving the discussion back to the issue for better visibility.
With z being 0, is it possible that the probability of `pyro.sample("obs", dist.Normal(3, z), obs=data)` being close to 0 accounts for the instability between the jit and nojit versions? The difference is large, but both are essentially 0 on the probability scale. Can we see this difference for other inputs with a higher value for `log_prob`?
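A sketch of the check asked for here (editor's addition; it assumes `fn` and the notebook globals are still in scope): evaluate at a `zu` whose transformed `z` equals the true scale 0.5, where the observation log-probability is far from the tails, and compare the eager and jit gradients there.

```python
# Sketch only (not from the gist): compare eager vs jit gradients at an input
# with a much higher log_prob, z = exp(zu) = 0.5 (the true scale of the data).
zu_hi = torch.log(torch.tensor(0.5)).requires_grad_()
fn_jit_hi = torch.jit.trace(fn, (zu_hi,))

print(fn(zu_hi))                                     # log_prob is far less extreme here
print(autograd.grad(fn(zu_hi), (zu_hi,))[0])         # eager gradient
print(autograd.grad(fn_jit_hi(zu_hi), (zu_hi,))[0])  # jit gradient
```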