tensorflow-pymc3-experiment
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import collections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
"# TODO We can use `tf.distributions` for the most part,\n", | |
"# but we might want to new ones, and maybe also add methods\n", | |
"# to them. Using raw `tf.distributions` objects for now.\n", | |
"class Distribution:\n", | |
" pass\n", | |
"\n", | |
"\n", | |
"class RandomVariable:\n", | |
" def __init__(self, name, dist, transform, dtype, shape):\n", | |
" if shape is None:\n", | |
" raise ValueError()\n", | |
" self._name = name\n", | |
" self._dtype = dtype\n", | |
" self._shape = shape\n", | |
" self._transform = transform\n", | |
" self._dist = dist\n", | |
" # check domain of dist vs transform.onto\n", | |
"\n", | |
" @property\n", | |
" def name(self):\n", | |
" return self._name\n", | |
" \n", | |
" @property\n", | |
" def dtype(self):\n", | |
" return self._dtype\n", | |
" \n", | |
" @property\n", | |
" def shape(self):\n", | |
" return self._shape\n", | |
" \n", | |
" def _logp(self, value, *, jacterms):\n", | |
" # TODO Pass transform to dists?\n", | |
" # That would allow specialized implementations.\n", | |
" transform = self._transform\n", | |
" if transform is not None:\n", | |
" value = transform.forward(value)\n", | |
" \n", | |
" logp = self._dist.log_prob(value)\n", | |
" if transform is not None and jacterms:\n", | |
" logp = logp + transform.jacdet(value)\n", | |
" return logp\n", | |
"\n", | |
" return self._dist.logp(self.var, self.trafo)\n", | |
" \n", | |
" def _logp_sum(self, value, *, jacterms):\n", | |
" # TODO give the distributions a chance to do this\n", | |
" # on their own.\n", | |
" return tf.reduce_sum(self._logp(value, jacterms=jacterms))\n", | |
"\n", | |
"\n", | |
"class Free(RandomVariable):\n", | |
" def __init__(self, name, dist, transform, dtype, shape):\n", | |
" super().__init__(name, dist, transform, dtype, shape)\n", | |
" self._var = tf.placeholder(dtype, shape, name)\n", | |
" \n", | |
" @property\n", | |
" def var(self):\n", | |
" return self._var\n", | |
" \n", | |
" def logp(self, *, jacterms=True):\n", | |
" return self._logp(self.var, jacterms=jacterms)\n", | |
" \n", | |
" def logp_sum(self, *, jacterms=True):\n", | |
" return self._logp_sum(self.var, jacterms=jacterms)\n", | |
"\n", | |
" \n", | |
"class Observed(RandomVariable):\n", | |
" def __init__(self, name, dist, transform, dtype, shape, observed):\n", | |
" super().__init__(name, dist, transform, dtype, shape)\n", | |
" if not isinstance(observed, Data):\n", | |
" observed = tf.constant(observed)\n", | |
" self._observed = observed\n", | |
" \n", | |
" @property\n", | |
" def observed(self):\n", | |
" return self._observed\n", | |
" \n", | |
" def logp(self, *, jacterms=True):\n", | |
" return self._logp(self.observed, jacterms=jacterms)\n", | |
"\n", | |
" def logp_sum(self, *, jacterms=True):\n", | |
" return self._logp_sum(self.observed, jacterms=jacterms)\n", | |
"\n", | |
"\n", | |
"class Data:\n", | |
" def __init__(self, name, dtype=None, shape=None, default=None):\n", | |
" # TODO\n", | |
" self._var = tf.Variable(dtype, shape, name)\n", | |
"\n", | |
"\n", | |
"class TfModel:\n", | |
" _context_stack = []\n", | |
"\n", | |
" def __init__(self):\n", | |
" self._free_vars = []\n", | |
" self._observed_vars = []\n", | |
" self._graph = tf.Graph()\n", | |
" self._graph_context = None\n", | |
" \n", | |
" def add_free_var(self, var):\n", | |
" self._free_vars.append(var)\n", | |
" \n", | |
" def add_observed_var(self, var):\n", | |
" self._observed_vars.append(var)\n", | |
" \n", | |
" def __enter__(self):\n", | |
" TfModel._context_stack.append(self)\n", | |
" self._graph_context = self._graph.as_default()\n", | |
" self._graph_context.__enter__()\n", | |
" return self\n", | |
" \n", | |
" def __exit__(self, *args, **kwargs):\n", | |
" old = TfModel._context_stack.pop()\n", | |
" self._graph_context.__exit__(*args, **kwargs)\n", | |
" assert old is self\n", | |
" \n", | |
" def _logp_sum(self, *, jacterms=True, reduce=True):\n", | |
" with self:\n", | |
" vars = self._free_vars + self._observed_vars\n", | |
" logp_free = [var.logp_sum(jacterms=jacterms)\n", | |
" for var in vars]\n", | |
" # TODO optional reduce?\n", | |
" return tf.reduce_sum(logp_free, name='logp__')\n", | |
" \n", | |
" def logp_function(self, *, jacterms=True, session=None):\n", | |
" raise NotImplementedError()\n", | |
" \n", | |
" def logp_dlogp_function(self, grad_vars=None, *, target=None, dtype=None,\n", | |
" jacterms=True, session_config=None, data=None):\n", | |
" if grad_vars is None:\n", | |
" # TODO check dtype\n", | |
" grad_vars = self._free_vars.copy()\n", | |
" cost = self._logp_sum(jacterms=jacterms)\n", | |
" return ValueGradFunction(cost, grad_vars, None, target=target,\n", | |
" dtype=dtype,\n", | |
" data=data, session_config=session_config)\n", | |
"\n", | |
"\n", | |
"def model_from_context(model):\n", | |
" if model is not None:\n", | |
" return model\n", | |
" if len(TfModel._context_stack) == 0:\n", | |
" raise ValueError('No model on context stack.')\n", | |
" return TfModel._context_stack[-1]\n", | |
"\n", | |
"\n", | |
"VarMap = collections.namedtuple(\n", | |
" 'VarMap', 'name, slice, shape, dtype')\n", | |
"\n", | |
"\n", | |
"class ArrayOrdering:\n", | |
" def __init__(self, vars, dtype, *, casting='no', data=None, session=None):\n", | |
" maps = []\n", | |
" total = 0\n", | |
" if session is None:\n", | |
" session = tf.get_default_session()\n", | |
" \n", | |
" for var in vars:\n", | |
" # TODO check casting\n", | |
" if var.dtype != dtype:\n", | |
" raise ValueError()\n", | |
" name = var.name\n", | |
" shape = var.shape\n", | |
" if shape is None or any(dim is None for dim in shape):\n", | |
" # TODO no session\n", | |
" shape = session.run(shape, data)\n", | |
" size = int(np.prod(shape))\n", | |
" slice_ = slice(total, total + size)\n", | |
" maps.append(VarMap(name, slice_, shape, var.dtype))\n", | |
" total += size\n", | |
" self.size = total\n", | |
" self.vmap = maps\n", | |
" self.dtype = dtype\n", | |
" \n", | |
" def array_to_dict(self, array):\n", | |
" vars = {}\n", | |
" for name, slice_, shape, dtype in slef.vmap:\n", | |
" # TODO check casting\n", | |
" vars[name] = array[slice_].astype(dtype)\n", | |
" return vars\n", | |
"\n", | |
" def dict_to_array(self, vars):\n", | |
" array = np.empty(self.size, dtype=self.dtype)\n", | |
" for var in vars:\n", | |
" data = vars[var.name]\n", | |
" data = np.asarray(data, order='C')\n", | |
" # TODO check casting\n", | |
" array[var.slice] = data\n", | |
" return array\n", | |
"\n", | |
"\n", | |
"class ValueGradFunction:\n", | |
" def __init__(self, cost, grad_vars, extra_vars=None, target=None,\n", | |
" dtype=None, casting='no', session_config=None, data=None):\n", | |
" if extra_vars is None:\n", | |
" extra_vars = []\n", | |
" \n", | |
" old_graph = cost.graph\n", | |
" cost_name = cost.name\n", | |
" grad_names = [var.name for var in grad_vars]\n", | |
" # TODO\n", | |
" assert not extra_vars\n", | |
" self._ordering = ArrayOrdering(grad_vars, dtype,\n", | |
" casting=casting, data=data)\n", | |
" \n", | |
" graph_def = old_graph.as_graph_def()\n", | |
" graph = tf.Graph()\n", | |
" with graph.as_default():\n", | |
" array = tf.placeholder(dtype, (self._ordering.size,), name='freeRV_array_')\n", | |
" self._array = array\n", | |
" var_slices = {}\n", | |
" for var in self._ordering.vmap:\n", | |
" var_slices[var.name] = tf.reshape(array[var.slice], var.shape)\n", | |
" cost_array, = tf.import_graph_def(graph_def, var_slices, [cost_name])\n", | |
" #cost_array = cost_array.outputs[0]\n", | |
" cost_grad = tf.gradients(cost_array, array)\n", | |
" sess = tf.Session(target=target, graph=graph, config=session_config)\n", | |
" self._session = sess\n", | |
" self._cost_array = cost_array\n", | |
" self._cost_grad = cost_grad\n", | |
" self._graph = graph\n", | |
" \n", | |
" def __call__(self, array):\n", | |
" return self._session.run([self._cost_array, self._cost_grad],\n", | |
" {self._array: array})" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# TODO find a nice way to do this semi-automatically for\n",
    "# different distributions in `tf.contrib.distributions`.\n",
"def Normal(name, mu, sigma, shape=None, dtype=None, transform=None, observed=None, model=None):\n", | |
" dist = tf.contrib.distributions.Normal(mu, sigma, name=name + '_dist__')\n", | |
" if shape is None:\n", | |
" shape = dist.batch_shape\n", | |
" if dtype is None:\n", | |
" dtype = tf.float32\n", | |
" model = model_from_context(model)\n", | |
"\n", | |
" if observed is None:\n", | |
" var = Free(name, dist, transform, dtype, shape)\n", | |
" model.add_free_var(var)\n", | |
" return var.var\n", | |
" else:\n", | |
" var = Observed(name, dist, transform, dtype, shape, observed)\n", | |
" model.add_observed_var(var)\n", | |
" return var.observed" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TfModel()\n",
    "\n",
"with model:\n", | |
" a = Normal('a', 0., 10.)\n", | |
" c = Normal('c', 0., 5.)\n", | |
" b = Normal('b', a, 1., observed=[3., 4., 3.5])" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
"logp_dlogp = model.logp_dlogp_function(dtype='float32')" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "833 µs ± 45.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
"%timeit logp_dlogp(np.zeros(2))" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "800 µs ± 33.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
"logp = model._logp_sum()\n", | |
"logp_grad = tf.gradients(logp, [var.var for var in model._free_vars])\n", | |
"\n", | |
"sess = tf.Session(graph=model._graph)\n", | |
"a_ = np.array(3.)\n", | |
"c_ = np.array(4.)\n", | |
"%timeit sess.run([logp, logp_grad], {a: a_, c: c_})" | |
] | |
} | |
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}