Skip to content

Instantly share code, notes, and snippets.

@xiangze
Created July 17, 2018 16:16
Show Gist options
  • Save xiangze/e15ac9fbe4bffa49347b9276fd44d4f2 to your computer and use it in GitHub Desktop.
Save xiangze/e15ac9fbe4bffa49347b9276fd44d4f2 to your computer and use it in GitHub Desktop.
hierachicalmodel_edward
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Edwardでの階層モデル"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- https://discourse.edwardlib.org/t/simple-hierarchical-model-fails/196\n",
"- http://willwolf.io/2017/06/15/random-effects-neural-networks/ (original question)\n",
"\n",
"- https://aksarkar.github.io/nwas/klqp.html (answer)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/Cellar/python3/3.6.3/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
" return f(*args, **kwds)\n",
"/Users/apple/Library/Python/3.6/lib/python/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"import edward as ed\n",
"from edward.models import Normal\n",
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# TOY DATA\n",
"N = 3 # number of groups\n",
"M = 1000 # samples per group\n",
"\n",
"# mean for each group is different\n",
"# want to infer the group means plus the overall mean\n",
"actual_group_means = [0.1, 0.2, 0.3]\n",
"sigma = 0.1\n",
"\n",
"observed_groups = np.repeat([0, 1, 2], M)\n",
"samples = [np.random.normal(actual_group_means[g], sigma, M) for g in range(N)]\n",
"observed_data = np.concatenate(samples)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MODEL"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"groups = tf.placeholder(tf.int32, [M * N])\n",
"\n",
"overall_mean = Normal(\n",
" loc=tf.zeros(1), \n",
" scale=tf.ones(1) * 0.05 )\n",
"\n",
"group_means = Normal(\n",
" loc=tf.ones(N) * overall_mean,\n",
" scale=tf.ones(N) * 0.05 )\n",
"\n",
"q_overall_mean = Normal(\n",
" loc=tf.Variable(tf.zeros(1)),\n",
" scale=tf.nn.softplus(tf.Variable(tf.zeros(1))) )\n",
"\n",
"q_group_means = Normal(\n",
" loc=tf.Variable(tf.zeros(N)),\n",
" scale=tf.nn.softplus(tf.Variable(tf.zeros(N))) )\n",
"\n",
"data = Normal(\n",
" loc=tf.gather(group_means, groups),\n",
" scale=tf.ones(shape=[N * M]) * sigma )\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/site-packages/edward/util/random_variables.py:52: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" not np.issubdtype(value.dtype, np.float) and \\\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000/1000 [100%] ██████████████████████████████ Elapsed: 7s | Loss: -2533.115\n",
"Using <class 'edward.inferences.klqp.ReparameterizationKLqp'>:\n",
"[0.13311534]\n",
"[0.0944147 0.20102091 0.2979718 ]\n",
"1000/1000 [100%] ██████████████████████████████ Elapsed: 4s | Loss: -2493.819\n",
"Using <class 'edward.inferences.klqp.ReparameterizationKLKLqp'>:\n",
"[0.]\n",
"[0.09101488 0.19820164 0.2882782 ]\n"
]
}
],
"source": [
"#for inference_alg in (ed.KLpq, ed.KLqp):\n",
"for inference_alg in (ed.ReparameterizationKLqp, ed.ReparameterizationKLKLqp):\n",
" inference = inference_alg(\n",
" {\n",
" overall_mean: q_overall_mean,\n",
" group_means: q_group_means\n",
" },\n",
" data={\n",
" groups: observed_groups,\n",
" data: observed_data\n",
" }\n",
" )\n",
" \n",
" inference.run(n_samples=5, n_iter=1000)\n",
" sess = ed.get_session()\n",
" print('Using {}:'.format(inference_alg))\n",
" print(q_overall_mean.mean().eval())\n",
" print(q_group_means.mean().eval())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment