Last active
May 5, 2019 23:18
-
-
Save ferrine/a4ffe58cb7b5469f5afa64ad6888f49a to your computer and use it in GitHub Desktop.
pymc4 playground
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 141, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import itertools\n", | |
"import scipy.stats\n", | |
"import collections\n", | |
"import functools\n", | |
"import types\n", | |
"Event = collections.namedtuple(\"Event\", \"sender,info\")\n", | |
"\n", | |
"class DebugPrint(Event):\n", | |
" \"\"\"Just a plasholder, whatever\"\"\"\n", | |
"\n", | |
"def notify(fn):\n", | |
" \"\"\"Some distributions may be reparametrizable, we may want to hook the behaviour\"\"\"\n", | |
" @functools.wraps(fn)\n", | |
" def wrapped(*args, **kwargs):\n", | |
" action = yield Event(fn, dict(args=args, kwargs=kwargs))\n", | |
" # do nothing, may be changed to\n", | |
" # value = yield from action(fn, *args, **kwargs)\n", | |
" # or\n", | |
" # value = yield from action(fn)(*args, **kwargs)\n", | |
" value = yield from fn(*args, **kwargs)\n", | |
" return value\n", | |
" return wrapped\n", | |
"\n", | |
"@notify\n", | |
"def horseshoe_model(tau):\n", | |
" # nested model.\n", | |
" # use philosophy \"model is a distribution!\"\n", | |
" lamb = yield \"lamb\", scipy.stats.halfcauchy(0, 1)\n", | |
" norm = yield \"norm\", scipy.stats.norm(0, lamb*tau)\n", | |
"\n", | |
" # uncomment to to debug print or whatever\n", | |
" #yield DebugPrint(\"user\", locals())\n", | |
" # raises error message\n", | |
" yield another(1)\n", | |
" return norm\n", | |
"\n", | |
"def another(ha):\n", | |
" yield \"lamb\", scipy.stats.halfcauchy(0, 1)\n", | |
" # an attempt to check g.gi_yieldfrom.gi_frame.f_locals\n", | |
" # and compare it with debug print\n", | |
" yield DebugPrint(\"user\", locals())\n", | |
"\n", | |
"def model():\n", | |
" # user defined model\n", | |
" # some distributions are Atomic and not models themselves\n", | |
" tau = yield \"tau\", scipy.stats.expon(1)\n", | |
" # soem are models and should be called with yield from syntax\n", | |
" hs = yield from horseshoe_model(tau)\n", | |
" # but what to do with Normal(rv, rv2) reparametrized?\n", | |
" # wrap them with JointDistributionCoroutine or what??\n", | |
" return {\"tau\": tau, \"hs\": hs}\n", | |
"\n", | |
"\n", | |
"def sample_prior(m):\n", | |
" g = m()\n", | |
" rv = None\n", | |
" while True:\n", | |
" try:\n", | |
" print(\"g.gi_yieldfrom\", g.gi_yieldfrom)\n", | |
" val = g.send(rv)\n", | |
" if isinstance(val, types.GeneratorType):\n", | |
" g.throw(RuntimeError(\"Tried to yield a generator, but not a distribution or Event, change `yield` to `yield from` \"))\n", | |
" if isinstance(val, Event):\n", | |
" print(\"got an event:\", val)\n", | |
" if isinstance(val, DebugPrint):\n", | |
" if g.gi_yieldfrom is None:\n", | |
" locs = g.gi_frame.f_locals\n", | |
" else:\n", | |
" locs = g.gi_yieldfrom.gi_frame.f_locals\n", | |
" print(\"You can also get g.gi_frame.f_locals\", locs)\n", | |
" else:\n", | |
" name, dist = val\n", | |
" print(\"recieving:\", name)\n", | |
" rv = dist.rvs()\n", | |
" print(\"sampled: \", rv)\n", | |
" except StopIteration as e:\n", | |
" print(\"interesting:\", e.args[0])\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 142, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"g.gi_yieldfrom None\n", | |
"recieving: tau\n", | |
"sampled: 1.5203764466999203\n", | |
"g.gi_yieldfrom None\n", | |
"got an event: Event(sender=<function horseshoe_model at 0x1157ed9d8>, info={'args': (1.5203764466999203,), 'kwargs': {}})\n", | |
"g.gi_yieldfrom <generator object horseshoe_model at 0x1157a5830>\n", | |
"recieving: lamb\n", | |
"sampled: 1.3797586508781439\n", | |
"g.gi_yieldfrom <generator object horseshoe_model at 0x1157a5830>\n", | |
"recieving: norm\n", | |
"sampled: 1.6091662050221651\n", | |
"g.gi_yieldfrom <generator object horseshoe_model at 0x1157a5830>\n", | |
"got an event: DebugPrint(sender='user', info={'norm': 1.6091662050221651, 'lamb': 1.3797586508781439, 'tau': 1.5203764466999203})\n", | |
"You can also get g.gi_frame.f_locals {'action': 1.5203764466999203, 'kwargs': {}, 'args': (1.5203764466999203,), 'fn': <function horseshoe_model at 0x1157ed9d8>}\n", | |
"g.gi_yieldfrom <generator object horseshoe_model at 0x1157a5830>\n" | |
] | |
}, | |
{ | |
"ename": "RuntimeError", | |
"evalue": "Tried to yield a generator, but not a distribution or Event, change `yield` to `yield from` ", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-142-1805250b2b54>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msample_prior\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-141-d5f0c3c45342>\u001b[0m in \u001b[0;36msample_prior\u001b[0;34m(m)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGeneratorType\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mthrow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Tried to yield a generator, but not a distribution or Event, change `yield` to `yield from` \"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mEvent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"got an event:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-141-d5f0c3c45342>\u001b[0m in \u001b[0;36mmodel\u001b[0;34m()\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mtau\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;34m\"tau\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscipy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstats\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexpon\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mhs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mhorseshoe_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtau\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"tau\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtau\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"hs\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mhs\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-141-d5f0c3c45342>\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mEvent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapped\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-141-d5f0c3c45342>\u001b[0m in \u001b[0;36mhorseshoe_model\u001b[0;34m(tau)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mDebugPrint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"user\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocals\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0manother\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnorm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: Tried to yield a generator, but not a distribution or Event, change `yield` to `yield from` " | |
] | |
} | |
], | |
"source": [ | |
"g = sample_prior(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 118, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"('tau', <scipy.stats._distn_infrastructure.rv_frozen at 0x1157d0940>)" | |
] | |
}, | |
"execution_count": 118, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"g = model()\n", | |
"next(g)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "pymc4", | |
"language": "python", | |
"name": "pymc4" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment