@fehiepsi
Last active October 26, 2018 07:05
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.autograd as autograd\n",
"import torch.distributions as torch_dist\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"import pyro.poutine as poutine"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"data = dist.Normal(3, 0.5).sample(torch.Size([1000]))\n",
"\n",
"def model():\n",
" z = pyro.sample(\"z\", dist.Gamma(1, 1))\n",
" pyro.sample(\"obs\", dist.Normal(3, z), obs=data)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"tmp_trace = poutine.trace(model).get_trace()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def fn(zu):\n",
" z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)\n",
" tmp_trace.nodes[\"z\"][\"value\"] = z\n",
" trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))\n",
" trace_poutine()\n",
" trace = trace_poutine.trace\n",
" return trace.log_prob_sum()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"zu = torch.tensor(0., requires_grad=True)\n",
"fn_jit = torch.jit.trace(fn, (zu,))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-730.8032)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"autograd.grad(fn(zu), (zu,))[0]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1729.8032)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"autograd.grad(fn_jit(zu), (zu,))[0]"
]
},
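{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough sanity check by hand (assuming `log_prob_sum` simply sums the site log-densities at `z`, with no Jacobian term): with $z = e^{z_u}$, the `z` site contributes $\\log p(z) = -z$ and each data point contributes $-\\log z - \\tfrac{1}{2}\\log 2\\pi - (x_i - 3)^2/(2z^2)$, so\n",
"\n",
"$$\\frac{d}{dz_u}\\log p = z\\Big(-1 - \\frac{N}{z} + \\sum_i \\frac{(x_i - 3)^2}{z^3}\\Big).$$\n",
"\n",
"At $z_u = 0$ (so $z = 1$), with $N = 1000$ and $x_i \\sim \\mathrm{Normal}(3, 0.5)$, this is $-1001 + \\sum_i (x_i - 3)^2$, roughly $-750$ in expectation, consistent with the nojit value $-730.8$ above; the jit value is off by roughly $N$."
]
},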
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### We make the same version to debug: fn1 should be the same as fn"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1729.8032)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def fn1(zu):\n",
" z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)\n",
" tmp_trace.nodes[\"z\"][\"value\"] = z\n",
" trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))\n",
" trace_poutine()\n",
" trace = trace_poutine.trace\n",
" probs = []\n",
" for name, site in trace.nodes.items():\n",
" if site[\"type\"] == \"sample\":\n",
" log_p = site[\"fn\"].log_prob(site[\"value\"]).sum()\n",
" probs.append(log_p)\n",
" return probs[0] + probs[1]\n",
"\n",
"zu = torch.tensor(0., requires_grad=True)\n",
"fn1_jit = torch.jit.trace(fn1, (zu,))\n",
"autograd.grad(fn1_jit(zu), (zu,))[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Yes, it is the same. And `jit` gives wrong the result."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How about not using trace?"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-730.8032)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def fn2(zu):\n",
" z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)\n",
" p1 = dist.Gamma(1, 1).log_prob(z).sum()\n",
" p2 = dist.Normal(3, z).log_prob(data).sum()\n",
" return p1 + p2\n",
"\n",
"zu = torch.tensor(0., requires_grad=True)\n",
"fn2_jit = torch.jit.trace(fn2, (zu,))\n",
"autograd.grad(fn2_jit(zu), (zu,))[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`jit` gives right result now."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Let's investigate. What is log_prob at z?"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def fn_at_z(zu):\n",
" z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)\n",
" tmp_trace.nodes[\"z\"][\"value\"] = z\n",
" trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))\n",
" trace_poutine()\n",
" trace = trace_poutine.trace\n",
" probs = []\n",
" for name, site in trace.nodes.items():\n",
" if site[\"type\"] == \"sample\":\n",
" log_p = site[\"fn\"].log_prob(site[\"value\"]).sum()\n",
" probs.append(log_p)\n",
" return probs[0]\n",
"\n",
"zu = torch.tensor(0., requires_grad=True)\n",
"fn_at_z_jit = torch.jit.trace(fn_at_z, (zu,))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**nojit**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(-1.)\n"
]
}
],
"source": [
"print(autograd.grad(fn_at_z(zu), (zu,))[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**jit**"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1.)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"autograd.grad(fn_at_z_jit(zu), (zu,))[0]"
]
},
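{
"cell_type": "markdown",
"metadata": {},
"source": [
"This matches a quick hand check: `probs[0]` is `Gamma(1, 1).log_prob(z) = -z` with `z = exp(zu)`, so its grad w.r.t. `zu` at `zu = 0` is `-exp(0) = -1`."
]
},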
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thing should be fine until now, but..."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(-1000.)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def fn1_at_z(zu):\n",
" z = torch_dist.biject_to(torch_dist.constraints.positive)(zu)\n",
" tmp_trace.nodes[\"z\"][\"value\"] = z\n",
" trace_poutine = poutine.trace(poutine.replay(model, tmp_trace))\n",
" trace_poutine()\n",
" trace = trace_poutine.trace\n",
" probs = []\n",
" for name, site in trace.nodes.items():\n",
" if site[\"type\"] == \"sample\":\n",
" log_p = site[\"fn\"].log_prob(site[\"value\"]).sum()\n",
" probs.append(log_p)\n",
" return probs[0] + 0 * probs[1] # !!! should be the same!\n",
"\n",
"zu = torch.tensor(0., requires_grad=True)\n",
"fn1_at_z = torch.jit.trace(fn1_at_z, (zu,))\n",
"autograd.grad(fn1_at_z(zu), (zu,))[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oppss, grad is off by 1000!!! (the size of the `data`)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@neerajprad

With z being 0, is it possible that the probability of pyro.sample("obs", dist.Normal(3, z), obs=data) being close to 0 accounts for instability between the jit and nojit versions? The difference is large but both are essentially 0 on the probability scale. Can we see this difference for other inputs with a higher value for log_prob?

@fehiepsi (Author)

z = 1 indeed (after transform). You can remove the transform code and set z = 1, but precision is not the point I want to make here.

On the probability scale, e^-1 is definitely different from e^-1000. And if the number of data points were 5000, the grad would be 5000x off!
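A minimal way to check that scaling claim (a sketch that reuses `model` and the Python `fn1_at_z` from the notebook, before it is re-bound to its jit trace; the data size 5000 is just for illustration):

```python
# regenerate the observed data with 5000 points; model() reads the global `data`
data = dist.Normal(3, 0.5).sample(torch.Size([5000]))
# re-record a model trace against the new data
tmp_trace = poutine.trace(model).get_trace()

zu = torch.tensor(0., requires_grad=True)
fn1_at_z_jit = torch.jit.trace(fn1_at_z, (zu,))

# the obs term is multiplied by 0, so the correct grad is still -1
print(autograd.grad(fn1_at_z(zu), (zu,))[0])      # expected: tensor(-1.)
# if the jit grad is off by the data size, this should come out around -5000
print(autograd.grad(fn1_at_z_jit(zu), (zu,))[0])
```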

@neerajprad

z = 1 indeed (after transform)

Sorry, missed the transform code, and this indeed is surprising. Does the jit throw any warnings at all?

@fehiepsi (Author)

There is no warning at all.

@neerajprad

Great to know, @fehiepsi! 👍 on the cool detective work. Moving the discussion back to the issue for better visibility.
