Skip to content

Instantly share code, notes, and snippets.

@cs224
Created July 10, 2020 05:58
Show Gist options
  • Save cs224/8259c737a41fc992b34a489daff2cec4 to your computer and use it in GitHub Desktop.
Save cs224/8259c737a41fc992b34a489daff2cec4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cs224 \n",
"last updated: 2020-07-10 \n",
"\n",
"CPython 3.6.10\n",
"IPython 7.15.0\n",
"\n",
"numpy 1.18.5\n",
"scipy 1.5.0\n",
"pandas 1.0.5\n",
"matplotlib 3.2.2\n",
"seaborn 0.10.1\n",
"pymc3 3.9.2\n",
"statsmodels 0.11.1\n",
"simdkalman 1.0.1\n",
"tensorflow 2.2.0\n",
"tensorflow_probability 0.10.1\n",
"arviz 0.9.0\n"
]
}
],
"source": [
"%load_ext watermark\n",
"%watermark -a 'cs224' -u -d -v -p numpy,scipy,pandas,matplotlib,seaborn,pymc3,statsmodels,simdkalman,tensorflow,tensorflow_probability,arviz"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>.container { width:70% !important; }</style>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import numpy as np, scipy, scipy.stats as stats, scipy.special, scipy.misc, pandas as pd, matplotlib.pyplot as plt, seaborn as sns\n",
"import matplotlib as mpl\n",
"import sympy\n",
"\n",
"pd.set_option('display.max_columns', 500)\n",
"pd.set_option('display.width', 1000)\n",
"# pd.set_option('display.float_format', lambda x: '%.2f' % x)\n",
"#np.set_printoptions(edgeitems=50)\n",
"np.set_printoptions(suppress=True)\n",
"np.set_printoptions(edgeitems=30, linewidth=100000) #, formatter=dict(float=lambda x: \"%.3g\" % x)\n",
"np.core.arrayprint._line_width = 500\n",
"\n",
"sns.set()\n",
"\n",
"from IPython.display import display, HTML\n",
"\n",
"display(HTML(\"<style>.container { width:70% !important; }</style>\"))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow.compat.v2 as tf\n",
"import tensorflow_probability as tfp\n",
"import arviz as az\n",
"from functools import partial"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TensorFlow Probability (TFP)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [NUTS](https://adamhaber.github.io/post/nuts/)\n",
"* [Tour of probabilistic programming language APIs](https://colcarroll.github.io/ppl-api/)\n",
"* [Reasoning about Shapes and Probability Distributions](https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/)\n",
" * [Examples](https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the Sampling Code\n",
"\n",
"Taken from here: https://adamhaber.github.io/post/nuts/"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# using pymc3 naming conventions, with log_likelihood instead of lp so that ArviZ can compute loo and waic\n",
"sample_stats_name = ['log_likelihood','tree_size','diverging','energy','mean_tree_accept']\n",
"\n",
"def tfp_trace_to_arviz(tfp_trace, var_names=None, sample_stats_name=sample_stats_name):\n",
" samps, trace = tfp_trace\n",
" if var_names is None:\n",
" var_names = [\"var \" + str(x) for x in range(len(samps))]\n",
" \n",
" sample_stats = {k:v.numpy().T for k, v in zip(sample_stats_name, trace)}\n",
" posterior = {name : tf.transpose(samp, [1, 0, 2]).numpy() for name, samp in zip(var_names, samps)}\n",
" return az.from_dict(posterior=posterior, sample_stats=sample_stats)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def trace_fn(_, pkr):\n",
"    \"\"\"Select the per-step quantities to record from the nested kernel results.\n",
"\n",
"    The first argument is ignored. All values are read from ``pkr.inner_results.inner_results``\n",
"    and returned as a tuple in this order: ``target_log_prob``, ``leapfrogs_taken``,\n",
"    ``has_divergence``, ``energy``, ``log_accept_ratio``.\n",
"    \"\"\"\n",
"    innermost = pkr.inner_results.inner_results\n",
"    target_log_prob = innermost.target_log_prob\n",
"    leapfrogs_taken = innermost.leapfrogs_taken\n",
"    has_divergence = innermost.has_divergence\n",
"    energy = innermost.energy\n",
"    log_accept_ratio = innermost.log_accept_ratio\n",
"    return (target_log_prob, leapfrogs_taken, has_divergence, energy, log_accept_ratio)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"n_chains = 10\n",
"\n",
"def run_nuts_template(trace_fn, target_log_prob_fn, inits, bijectors_list=None, num_steps=500, num_burnin=500, n_chains=n_chains):\n",
"    \"\"\"Sample with NUTS wrapped in a transformed kernel and dual-averaging step-size adaptation.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    trace_fn : callable\n",
"        Passed unchanged to ``tfp.mcmc.sample_chain`` as ``trace_fn``.\n",
"    target_log_prob_fn : callable\n",
"        Log-density handed to ``tfp.mcmc.NoUTurnSampler``.\n",
"    inits : tensor or list of tensors\n",
"        Initial state; anything that is not a ``list`` is wrapped into a one-element list.\n",
"    bijectors_list : list, optional\n",
"        One bijector per state part; defaults to an identity bijector for every part.\n",
"    num_steps : int\n",
"        Number of results requested from ``sample_chain``.\n",
"    num_burnin : int\n",
"        Number of burn-in steps; 80% of them (truncated to an int) are used for adaptation.\n",
"    n_chains : int\n",
"        Number of rows of the randomly drawn initial step-size array.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The result of ``tfp.mcmc.sample_chain``.\n",
"    \"\"\"\n",
"    # initial step sizes of shape (n_chains, 1), drawn uniformly from [1.0, 1.5)\n",
"    step_size = np.random.rand(n_chains, 1)*.5 + 1.\n",
"\n",
"    if not isinstance(inits, list):\n",
"        inits = [inits]\n",
"    n_parts = len(inits)\n",
"\n",
"    if bijectors_list is None:\n",
"        bijectors_list = [tfp.bijectors.Identity()]*n_parts\n",
"\n",
"    # The adaptation kernel wraps TransformedTransitionKernel(NoUTurnSampler), so the\n",
"    # step size and the accept ratio live one level down, in ``pkr.inner_results``.\n",
"    def set_step_size(pkr, new_step_size):\n",
"        updated_inner = pkr.inner_results._replace(step_size=new_step_size)\n",
"        return pkr._replace(inner_results=updated_inner)\n",
"\n",
"    def get_step_size(pkr):\n",
"        return pkr.inner_results.step_size\n",
"\n",
"    def get_log_accept_prob(pkr):\n",
"        return pkr.inner_results.log_accept_ratio\n",
"\n",
"    nuts_kernel = tfp.mcmc.NoUTurnSampler(\n",
"        target_log_prob_fn,\n",
"        step_size=[step_size]*n_parts\n",
"    )\n",
"    transformed_kernel = tfp.mcmc.TransformedTransitionKernel(\n",
"        inner_kernel=nuts_kernel,\n",
"        bijector=bijectors_list\n",
"    )\n",
"    kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(\n",
"        transformed_kernel,\n",
"        target_accept_prob=.8,\n",
"        num_adaptation_steps=int(0.8*num_burnin),\n",
"        step_size_setter_fn=set_step_size,\n",
"        step_size_getter_fn=get_step_size,\n",
"        log_accept_prob_getter_fn=get_log_accept_prob,\n",
"    )\n",
"\n",
"    return tfp.mcmc.sample_chain(\n",
"        num_results=num_steps,\n",
"        num_burnin_steps=num_burnin,\n",
"        current_state=inits,\n",
"        kernel=kernel,\n",
"        trace_fn=trace_fn\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"run_nuts = partial(run_nuts_template, trace_fn)\n",
"\n",
"run_nuts_opt = tf.function(run_nuts)\n",
"run_nuts_defun = tf.function(run_nuts, autograph=False)\n",
"run_nuts_defun_xla = tf.function(run_nuts, autograph=False, experimental_compile=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the test data\n",
"\n",
"Normal distribution with μ=-3.0 and σ=5.0"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DescribeResult(nobs=10000, minmax=(-22.920290192922792, 17.510972306729627), mean=-2.9322397586814573, variance=25.254547939735094, skewness=0.017851286422333672, kurtosis=-0.07929113135882782)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = scipy.stats.norm(loc=-3.0, scale=5).rvs(10000)\n",
"scipy.stats.describe(data)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"dtype=tf.float32"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>10000.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>-2.932240</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>5.025390</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>-22.920290</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>-6.332816</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>-2.980136</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.504405</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>17.510972</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" y\n",
"count 10000.000000\n",
"mean -2.932240\n",
"std 5.025390\n",
"min -22.920290\n",
"25% -6.332816\n",
"50% -2.980136\n",
"75% 0.504405\n",
"max 17.510972"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(data, columns = ['y'])\n",
"df.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the model via JointDistributionCoroutine"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def simple_normal():\n",
"    \"\"\"Coroutine model: Normal(0, 100) prior on mu, HalfCauchy(0, 1) prior on sigma, Normal(mu, sigma) for y.\n",
"\n",
"    Both priors are wrapped in ``Sample(..., 1)`` and yielded as root nodes; the last yield is the\n",
"    ``Independent``-wrapped Normal built from the two values sent back into the coroutine.\n",
"    \"\"\"\n",
"    tfd = tfp.distributions\n",
"    as_root = tfd.JointDistributionCoroutine.Root\n",
"\n",
"    mu_prior = tfd.Sample(tfd.Normal(loc=0.0, scale=100.0, name=\"mu\"), 1)\n",
"    μ = yield as_root(mu_prior)\n",
"\n",
"    sigma_prior = tfd.Sample(tfd.HalfCauchy(loc=0.0, scale=1.0, name=\"sigma\"), 1)\n",
"    σ = yield as_root(sigma_prior)\n",
"\n",
"    # the value sent back for the observation node is not needed inside the coroutine\n",
"    yield tfd.Independent(tfd.Normal(loc=μ, scale=σ))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"simple_normal_jd = tfp.distributions.JointDistributionCoroutine(simple_normal)\n",
"simple_normal_log_prob = lambda *args: simple_normal_jd.log_prob(args + (tf.cast(df['y'],dtype),))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tfp.distributions.JointDistributionCoroutine 'JointDistributionCoroutine' batch_shape=([], [], [1]) event_shape=([1], [1], []) dtype=(float32, float32, float32)>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"simple_normal_jd"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.<lambda>(*args)>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"simple_normal_log_prob"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(10000,), dtype=float32, numpy=array([ -691.8899 , -152.8568 , -13.232361 , -707.83527 , -172.60931 , -993.58655 , -41.972393 , -22.376057 , -167.41461 , -34.78069 , -10.01722 , -41.879856 , -505.69324 , -1203.8346 , -14.015528 , -205.84846 , -143.15659 , -623.37177 , -567.86224 , -26.278309 , -1012.9196 , -1236.1648 , -1284.6156 , -605.8612 , -1159.5087 , -1185.9105 , -194.47389 , -15.63607 , -183.93636 , -19.521341 , ..., -1012.02954 , -273.9243 , -63.970608 , -5.4297695, -52.786026 , -12.76102 , -20.700647 , -243.86974 , -166.18301 , -962.8709 , -384.02362 , -1175.0737 , -468.0815 , -1050.8275 , -5.730259 , -685.54553 , -655.22656 , -19.387684 , -67.51332 , -768.34314 , -16.29623 , -169.35503 , -1485.6156 , -2207.6409 , -769.7052 , -641.5705 , -22.50534 , -340.9507 , -111.759346 , -1218.7783 ], dtype=float32)>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mu = 0.1\n",
"sigma = 0.2\n",
"simple_normal_log_prob(mu, sigma)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 31.2 s, sys: 5.73 s, total: 36.9 s\n",
"Wall time: 10.1 s\n"
]
}
],
"source": [
"%%time\n",
"res = run_nuts_defun(simple_normal_log_prob, [tf.ones((n_chains, 1)), tf.ones((n_chains, 1))])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"trace1 = tfp_trace_to_arviz(res,['mu','sigma'])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_mean</th>\n",
" <th>ess_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>mu[0]</th>\n",
" <td>-2.932</td>\n",
" <td>0.050</td>\n",
" <td>-3.024</td>\n",
" <td>-2.840</td>\n",
" <td>0.001</td>\n",
" <td>0.001</td>\n",
" <td>4399.0</td>\n",
" <td>4399.0</td>\n",
" <td>4397.0</td>\n",
" <td>3101.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>sigma[0]</th>\n",
" <td>5.026</td>\n",
" <td>0.036</td>\n",
" <td>4.959</td>\n",
" <td>5.092</td>\n",
" <td>0.001</td>\n",
" <td>0.000</td>\n",
" <td>4659.0</td>\n",
" <td>4658.0</td>\n",
" <td>4664.0</td>\n",
" <td>3758.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_mean ess_sd ess_bulk ess_tail r_hat\n",
"mu[0] -2.932 0.050 -3.024 -2.840 0.001 0.001 4399.0 4399.0 4397.0 3101.0 1.0\n",
"sigma[0] 5.026 0.036 4.959 5.092 0.001 0.000 4659.0 4658.0 4664.0 3758.0 1.0"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"az.summary(trace1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the model via JointDistributionSequential"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"simple_normal_model = tfp.distributions.JointDistributionSequential(\n",
" [\n",
" tfp.distributions.Sample(tfp.distributions.Normal(loc=0.0, scale=100.0, name=\"mu\"), 1),\n",
" tfp.distributions.Sample(tfp.distributions.HalfCauchy(loc=0.0, scale=1.0, name=\"sigma\"), 1),\n",
" lambda sigma, mu: # reverse order!\n",
" tfp.distributions.Independent(tfp.distributions.Normal(loc=mu, scale=sigma)) \n",
" ]\n",
")\n",
"simple_normal_model_log_prob = lambda *args: simple_normal_model.log_prob(args + (tf.cast(df['y'],dtype),))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tfp.distributions.JointDistributionSequential 'JointDistributionSequential' batch_shape=[[], [], [1]] event_shape=[[1], [1], []] dtype=[float32, float32, float32]>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"simple_normal_model"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.<lambda>(*args)>"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"simple_normal_model_log_prob"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(10000,), dtype=float32, numpy=array([ -691.8899 , -152.8568 , -13.232361 , -707.83527 , -172.60931 , -993.58655 , -41.972393 , -22.376057 , -167.41461 , -34.78069 , -10.01722 , -41.879856 , -505.69324 , -1203.8346 , -14.015528 , -205.84846 , -143.15659 , -623.37177 , -567.86224 , -26.278309 , -1012.9196 , -1236.1648 , -1284.6156 , -605.8612 , -1159.5087 , -1185.9105 , -194.47389 , -15.63607 , -183.93636 , -19.521341 , ..., -1012.02954 , -273.9243 , -63.970608 , -5.4297695, -52.786026 , -12.76102 , -20.700647 , -243.86974 , -166.18301 , -962.8709 , -384.02362 , -1175.0737 , -468.0815 , -1050.8275 , -5.730259 , -685.54553 , -655.22656 , -19.387684 , -67.51332 , -768.34314 , -16.29623 , -169.35503 , -1485.6156 , -2207.6409 , -769.7052 , -641.5705 , -22.50534 , -340.9507 , -111.759346 , -1218.7783 ], dtype=float32)>"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"simple_normal_model_log_prob(mu, sigma)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 33.3 s, sys: 6.04 s, total: 39.3 s\n",
"Wall time: 10.8 s\n"
]
}
],
"source": [
"%%time\n",
"res = run_nuts_defun(simple_normal_model_log_prob, [tf.ones((n_chains, 1)), tf.ones((n_chains, 1))])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"trace2 = tfp_trace_to_arviz(res,['mu','sigma'])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_mean</th>\n",
" <th>ess_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>mu[0]</th>\n",
" <td>-2.934</td>\n",
" <td>0.051</td>\n",
" <td>-3.028</td>\n",
" <td>-2.838</td>\n",
" <td>0.001</td>\n",
" <td>0.001</td>\n",
" <td>4357.0</td>\n",
" <td>4355.0</td>\n",
" <td>4351.0</td>\n",
" <td>3178.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>sigma[0]</th>\n",
" <td>5.026</td>\n",
" <td>0.035</td>\n",
" <td>4.960</td>\n",
" <td>5.091</td>\n",
" <td>0.001</td>\n",
" <td>0.000</td>\n",
" <td>4723.0</td>\n",
" <td>4722.0</td>\n",
" <td>4730.0</td>\n",
" <td>4197.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_mean ess_sd ess_bulk ess_tail r_hat\n",
"mu[0] -2.934 0.051 -3.028 -2.838 0.001 0.001 4357.0 4355.0 4351.0 3178.0 1.0\n",
"sigma[0] 5.026 0.035 4.960 5.091 0.001 0.000 4723.0 4722.0 4730.0 4197.0 1.0"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"az.summary(trace2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the model via log_prob function"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def target_log_prob_fn(mu, sigma):\n",
"    \"\"\"Unnormalized log posterior of (mu, sigma) given the observations in ``df['y']``.\n",
"\n",
"    Priors: ``mu ~ Normal(0, 100)``, ``sigma ~ HalfCauchy(0, 1)``; likelihood: ``y ~ Normal(mu, sigma)``.\n",
"\n",
"    Only the trailing axis is summed, so that\n",
"\n",
"    * scalar ``mu`` / ``sigma`` give a scalar (same value as summing everything), and\n",
"    * batched states of shape ``(n_chains, 1)`` give one value per chain, shape ``(n_chains,)``.\n",
"\n",
"    Summing over all axes would also add up the chains and return a single scalar for the whole\n",
"    batch, so the sampler could no longer tell the chains apart.\n",
"    \"\"\"\n",
"    tfd = tfp.distributions\n",
"    y_obs = tf.cast(df['y'], dtype)\n",
"\n",
"    mu_dist = tfd.Normal(loc=0.0, scale=100.0, name=\"mu\")\n",
"    sigma_dist = tfd.HalfCauchy(loc=0.0, scale=1.0, name=\"sigma\")\n",
"    # shape () for scalar inputs, (n_chains, 1) for batched inputs\n",
"    prior_log_prob = mu_dist.log_prob(mu) + sigma_dist.log_prob(sigma)\n",
"\n",
"    y_dist = tfd.Normal(loc=mu, scale=sigma)\n",
"    # sum over the observations only; keepdims so the result lines up with the prior term\n",
"    data_log_prob = tf.reduce_sum(y_dist.log_prob(y_obs), axis=-1, keepdims=True)\n",
"\n",
"    # drop the trailing length-1 axis: () for scalar inputs, (n_chains,) for batched inputs\n",
"    return tf.reduce_sum(prior_log_prob + data_log_prob, axis=-1)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.target_log_prob_fn(mu, sigma)>"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target_log_prob_fn"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=float32, numpy=-4298913.0>"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target_log_prob_fn(mu, sigma)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 34.3 s, sys: 6 s, total: 40.3 s\n",
"Wall time: 10.7 s\n"
]
}
],
"source": [
"%%time\n",
"res = run_nuts_defun(target_log_prob_fn, [tf.ones((n_chains, 1)), tf.ones((n_chains, 1))])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"trace3 = tfp_trace_to_arviz(res,['mu','sigma'])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_mean</th>\n",
" <th>ess_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>mu[0]</th>\n",
" <td>-2.932</td>\n",
" <td>0.050</td>\n",
" <td>-3.020</td>\n",
" <td>-2.831</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>13517.0</td>\n",
" <td>13480.0</td>\n",
" <td>13502.0</td>\n",
" <td>3670.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>sigma[0]</th>\n",
" <td>5.025</td>\n",
" <td>0.035</td>\n",
" <td>4.965</td>\n",
" <td>5.095</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>12171.0</td>\n",
" <td>12171.0</td>\n",
" <td>12230.0</td>\n",
" <td>4367.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_mean ess_sd ess_bulk ess_tail r_hat\n",
"mu[0] -2.932 0.050 -3.020 -2.831 0.0 0.0 13517.0 13480.0 13502.0 3670.0 1.0\n",
"sigma[0] 5.025 0.035 4.965 5.095 0.0 0.0 12171.0 12171.0 12230.0 4367.0 1.0"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"az.summary(trace3)"
]
},
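{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the model via log_prob function: problematic, my original attempt\n",
"\n",
"I do not understand what causes the problem below, because for scalar `mu` and `sigma` the output of `target_log_prob_fn_problematic` is exactly the same as the output of `simple_normal_log_prob` and `simple_normal_model_log_prob`.\n",
"\n",
"Note: the outputs were only compared for scalar inputs. The sampler is started with batched states of shape `(n_chains, 1)`, and the error message below reports `updates.shape=[1,10,10000]`, i.e. for batched states this function returns one value per chain *and* per observation. Presumably the sampler expects a single log-probability per chain."
]
},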
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"def target_log_prob_fn_problematic(mu, sigma):\n",
"    \"\"\"Sum of the two prior log-probabilities and the element-wise data log-probabilities.\n",
"\n",
"    Priors: ``mu ~ Normal(0, 100)``, ``sigma ~ HalfCauchy(0, 1)``; likelihood: ``y ~ Normal(mu, sigma)``\n",
"    evaluated on ``df['y']``. Nothing is reduced here: the result has the broadcast shape of the\n",
"    three terms, e.g. one value per observation when ``mu`` and ``sigma`` are scalars.\n",
"    \"\"\"\n",
"    tfd = tfp.distributions\n",
"    observations = tf.cast(df['y'], dtype)\n",
"\n",
"    prior_mu = tfd.Normal(loc=0.0, scale=100.0, name=\"mu\")\n",
"    prior_sigma = tfd.HalfCauchy(loc=0.0, scale=1.0, name=\"sigma\")\n",
"    likelihood = tfd.Normal(loc=mu, scale=sigma)\n",
"\n",
"    return prior_mu.log_prob(mu) + prior_sigma.log_prob(sigma) + likelihood.log_prob(observations)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<function __main__.target_log_prob_fn_problematic(mu, sigma)>"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target_log_prob_fn_problematic"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(10000,), dtype=float32, numpy=array([ -691.8899 , -152.8568 , -13.232361 , -707.83527 , -172.60931 , -993.58655 , -41.972393 , -22.376057 , -167.41461 , -34.78069 , -10.01722 , -41.879856 , -505.69324 , -1203.8346 , -14.015528 , -205.84846 , -143.15659 , -623.37177 , -567.86224 , -26.278309 , -1012.9196 , -1236.1648 , -1284.6156 , -605.8612 , -1159.5087 , -1185.9105 , -194.47389 , -15.63607 , -183.93636 , -19.521341 , ..., -1012.02954 , -273.9243 , -63.970608 , -5.4297695, -52.786026 , -12.76102 , -20.700647 , -243.86974 , -166.18301 , -962.8709 , -384.02362 , -1175.0737 , -468.0815 , -1050.8275 , -5.730259 , -685.54553 , -655.22656 , -19.387684 , -67.51332 , -768.34314 , -16.29623 , -169.35503 , -1485.6156 , -2207.6409 , -769.7052 , -641.5705 , -22.50534 , -340.9507 , -111.759346 , -1218.7783 ], dtype=float32)>"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target_log_prob_fn_problematic(mu, sigma)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"scrolled": false
},
"outputs": [
{
"ename": "InvalidArgumentError",
"evalue": "The inner 2 dimensions of output.shape=[11,10,1] must match the inner 2 dimensions of updates.shape=[1,10,10000] [Op:TensorScatterUpdate]",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31m_FallbackException\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py\u001b[0m in \u001b[0;36mtensor_scatter_update\u001b[0;34m(tensor, indices, updates, name)\u001b[0m\n\u001b[1;32m 10806\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_context_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtld\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"TensorScatterUpdate\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m> 10807\u001b[0;31m tld.op_callbacks, tensor, indices, updates)\n\u001b[0m\u001b[1;32m 10808\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31m_FallbackException\u001b[0m: This function does not handle the case of the path where all inputs are not already EagerTensors.",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-35-c912a2d07082>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_nuts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_log_prob_fn_problematic\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_chains\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_chains\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-6-db6db34d1737>\u001b[0m in \u001b[0;36mrun_nuts_template\u001b[0;34m(trace_fn, target_log_prob_fn, inits, bijectors_list, num_steps, num_burnin, n_chains)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mcurrent_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minits\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0mkernel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mtrace_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrace_fn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m )\n\u001b[1;32m 36\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/sample.py\u001b[0m in \u001b[0;36msample_chain\u001b[0;34m(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, name)\u001b[0m\n\u001b[1;32m 357\u001b[0m trace_fn(*state_and_results)),\n\u001b[1;32m 358\u001b[0m \u001b[0;31m# pylint: enable=g-long-lambda\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m parallel_iterations=parallel_iterations)\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_final_kernel_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/internal/util.py\u001b[0m in \u001b[0;36mtrace_scan\u001b[0;34m(loop_fn, initial_state, elems, trace_fn, parallel_iterations, name)\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_body\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0mloop_vars\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitial_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace_arrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m parallel_iterations=parallel_iterations)\n\u001b[0m\u001b[1;32m 395\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0mstacked_trace\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap_structure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace_arrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__module__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'in a future version'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m if date is None else ('after %s' % date), instructions)\n\u001b[0;32m--> 574\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m doc = _add_deprecated_arg_value_notice_to_docstring(\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop_v2\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)\u001b[0m\n\u001b[1;32m 2489\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2490\u001b[0m \u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2491\u001b[0;31m return_same_structure=True)\n\u001b[0m\u001b[1;32m 2492\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2493\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)\u001b[0m\n\u001b[1;32m 2725\u001b[0m list(loop_vars))\n\u001b[1;32m 2726\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2727\u001b[0;31m \u001b[0mloop_vars\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2728\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtry_to_pack\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_basetuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2729\u001b[0m \u001b[0mpacked\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/internal/util.py\u001b[0m in \u001b[0;36m_body\u001b[0;34m(i, state, trace_arrays)\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 382\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_body\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace_arrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 383\u001b[0;31m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloop_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0melems_array\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 384\u001b[0m trace_arrays = tf.nest.pack_sequence_as(trace_arrays, [\n\u001b[1;32m 385\u001b[0m a.write(i, v) for a, v in zip(\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/sample.py\u001b[0m in \u001b[0;36m_trace_scan_fn\u001b[0;34m(state_and_results, num_steps)\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0mbody_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mone_step\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0minitial_loop_vars\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate_and_results\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 343\u001b[0;31m parallel_iterations=parallel_iterations)\n\u001b[0m\u001b[1;32m 344\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_kernel_results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 345\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/internal/util.py\u001b[0m in \u001b[0;36msmart_for_loop\u001b[0;34m(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, name)\u001b[0m\n\u001b[1;32m 314\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[0mloop_vars\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint32\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0minitial_loop_vars\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 316\u001b[0;31m \u001b[0mparallel_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparallel_iterations\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 317\u001b[0m )[1:]\n\u001b[1;32m 318\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitial_loop_vars\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__module__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'in a future version'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m if date is None else ('after %s' % date), instructions)\n\u001b[0;32m--> 574\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m doc = _add_deprecated_arg_value_notice_to_docstring(\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop_v2\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)\u001b[0m\n\u001b[1;32m 2489\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2490\u001b[0m \u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2491\u001b[0;31m return_same_structure=True)\n\u001b[0m\u001b[1;32m 2492\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2493\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)\u001b[0m\n\u001b[1;32m 2725\u001b[0m list(loop_vars))\n\u001b[1;32m 2726\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2727\u001b[0;31m \u001b[0mloop_vars\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2728\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtry_to_pack\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_basetuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2729\u001b[0m \u001b[0mpacked\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/internal/util.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(i, *args)\u001b[0m\n\u001b[1;32m 312\u001b[0m return tf.while_loop(\n\u001b[1;32m 313\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mloop_num_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0mbody\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0mloop_vars\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint32\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0minitial_loop_vars\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[0mparallel_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparallel_iterations\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py\u001b[0m in \u001b[0;36mone_step\u001b[0;34m(self, current_state, previous_kernel_results)\u001b[0m\n\u001b[1;32m 436\u001b[0m \u001b[0;31m# Step the inner kernel.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 437\u001b[0m new_state, new_inner_results = self.inner_kernel.one_step(\n\u001b[0;32m--> 438\u001b[0;31m current_state, inner_results)\n\u001b[0m\u001b[1;32m 439\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[0;31m# Get the new step size.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py\u001b[0m in \u001b[0;36mone_step\u001b[0;34m(self, current_state, previous_kernel_results)\u001b[0m\n\u001b[1;32m 321\u001b[0m transformed_next_state, kernel_results = self._inner_kernel.one_step(\n\u001b[1;32m 322\u001b[0m \u001b[0mprevious_kernel_results\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransformed_state\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 323\u001b[0;31m previous_kernel_results.inner_results)\n\u001b[0m\u001b[1;32m 324\u001b[0m transformed_next_state_parts = (\n\u001b[1;32m 325\u001b[0m \u001b[0mtransformed_next_state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/nuts.py\u001b[0m in \u001b[0;36mone_step\u001b[0;34m(self, current_state, previous_kernel_results)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[0minitial_step_state\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 374\u001b[0m initial_step_metastate),\n\u001b[0;32m--> 375\u001b[0;31m \u001b[0mparallel_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel_iterations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 376\u001b[0m )\n\u001b[1;32m 377\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__module__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'in a future version'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m if date is None else ('after %s' % date), instructions)\n\u001b[0;32m--> 574\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m doc = _add_deprecated_arg_value_notice_to_docstring(\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop_v2\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)\u001b[0m\n\u001b[1;32m 2489\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2490\u001b[0m \u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2491\u001b[0;31m return_same_structure=True)\n\u001b[0m\u001b[1;32m 2492\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2493\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)\u001b[0m\n\u001b[1;32m 2725\u001b[0m list(loop_vars))\n\u001b[1;32m 2726\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2727\u001b[0;31m \u001b[0mloop_vars\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2728\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtry_to_pack\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_basetuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2729\u001b[0m \u001b[0mpacked\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/nuts.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(iter_, state, metastate)\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[0miter_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 370\u001b[0;31m metastate),\n\u001b[0m\u001b[1;32m 371\u001b[0m loop_vars=(\n\u001b[1;32m 372\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'iter'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/nuts.py\u001b[0m in \u001b[0;36mloop_tree_doubling\u001b[0;34m(self, step_size, momentum_state_memory, current_step_meta_info, iter_, initial_step_state, initial_step_metastate)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0minitial_step_metastate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontinue_tree\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0minitial_step_metastate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnot_divergence\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m momentum_state_memory)\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0mlast_candidate_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitial_step_metastate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcandidate_state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/nuts.py\u001b[0m in \u001b[0;36m_build_sub_tree\u001b[0;34m(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, name)\u001b[0m\n\u001b[1;32m 731\u001b[0m \u001b[0mmomentum_state_memory\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 732\u001b[0m ),\n\u001b[0;32m--> 733\u001b[0;31m \u001b[0mparallel_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel_iterations\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 734\u001b[0m )\n\u001b[1;32m 735\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py\u001b[0m in \u001b[0;36mnew_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__module__\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'in a future version'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 573\u001b[0m if date is None else ('after %s' % date), instructions)\n\u001b[0;32m--> 574\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 576\u001b[0m doc = _add_deprecated_arg_value_notice_to_docstring(\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop_v2\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)\u001b[0m\n\u001b[1;32m 2489\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2490\u001b[0m \u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmaximum_iterations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2491\u001b[0;31m return_same_structure=True)\n\u001b[0m\u001b[1;32m 2492\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2493\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)\u001b[0m\n\u001b[1;32m 2725\u001b[0m list(loop_vars))\n\u001b[1;32m 2726\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2727\u001b[0;31m \u001b[0mloop_vars\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2728\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtry_to_pack\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloop_vars\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_basetuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2729\u001b[0m \u001b[0mpacked\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/nuts.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(iter_, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory)\u001b[0m\n\u001b[1;32m 719\u001b[0m \u001b[0miter_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menergy_diff_sum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_momentum_cumsum\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 720\u001b[0m \u001b[0mleapfrogs_taken\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate_c\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontinue_tree\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 721\u001b[0;31m not_divergence, momentum_state_memory)),\n\u001b[0m\u001b[1;32m 722\u001b[0m loop_vars=(\n\u001b[1;32m 723\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'iter'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/nuts.py\u001b[0m in \u001b[0;36m_loop_build_sub_tree\u001b[0;34m(self, directions, integrator, current_step_meta_info, iter_, energy_diff_sum_previous, momentum_cumsum_previous, leapfrogs_taken, prev_tree_state, candidate_tree_state, continue_tree_previous, not_divergent_previous, momentum_state_memory)\u001b[0m\n\u001b[1;32m 811\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_scatter_nd_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mold\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mwrite_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnew\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 812\u001b[0m for old, new in zip(momentum_state_memory.momentum_swap,\n\u001b[0;32m--> 813\u001b[0;31m next_momentum_parts)\n\u001b[0m\u001b[1;32m 814\u001b[0m ],\n\u001b[1;32m 815\u001b[0m state_swap=[\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow_probability/python/mcmc/nuts.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 810\u001b[0m momentum_swap=[\n\u001b[1;32m 811\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor_scatter_nd_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mold\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mwrite_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mnew\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 812\u001b[0;31m for old, new in zip(momentum_state_memory.momentum_swap,\n\u001b[0m\u001b[1;32m 813\u001b[0m next_momentum_parts)\n\u001b[1;32m 814\u001b[0m ],\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py\u001b[0m in \u001b[0;36mtensor_scatter_update\u001b[0;34m(tensor, indices, updates, name)\u001b[0m\n\u001b[1;32m 10810\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10811\u001b[0m return tensor_scatter_update_eager_fallback(\n\u001b[0;32m> 10812\u001b[0;31m tensor, indices, updates, name=name, ctx=_ctx)\n\u001b[0m\u001b[1;32m 10813\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0m_core\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_SymbolicException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10814\u001b[0m \u001b[0;32mpass\u001b[0m \u001b[0;31m# Add nodes to the TensorFlow graph.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py\u001b[0m in \u001b[0;36mtensor_scatter_update_eager_fallback\u001b[0;34m(tensor, indices, updates, name, ctx)\u001b[0m\n\u001b[1;32m 10854\u001b[0m \u001b[0m_attrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\"T\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_attr_T\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Tindices\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_attr_Tindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10855\u001b[0m _result = _execute.execute(b\"TensorScatterUpdate\", 1, inputs=_inputs_flat,\n\u001b[0;32m> 10856\u001b[0;31m attrs=_attrs, ctx=ctx, name=name)\n\u001b[0m\u001b[1;32m 10857\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_execute\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmust_record_gradient\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10858\u001b[0m _execute.record_gradient(\n",
"\u001b[0;32m/home/local/cs/local/install/anaconda3-2020.02-Linux-x86_64/envs/py36ds/lib/python3.6/site-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0;32m---> 60\u001b[0;31m inputs, attrs, num_outputs)\n\u001b[0m\u001b[1;32m 61\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mInvalidArgumentError\u001b[0m: The inner 2 dimensions of output.shape=[11,10,1] must match the inner 2 dimensions of updates.shape=[1,10,10000] [Op:TensorScatterUpdate]"
]
}
],
"source": [
"res = run_nuts(target_log_prob_fn_problematic, [tf.ones((n_chains, 1)), tf.ones((n_chains, 1))])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment