Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Last active April 6, 2019 02:19
Show Gist options
  • Save fehiepsi/cd8409d0a3feb3e3f287b12b9960d13c to your computer and use it in GitHub Desktop.
Save fehiepsi/cd8409d0a3feb3e3f287b12b9960d13c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as np\n",
"from jax import jit, lax, random\n",
"from numpyro.hmc_util import build_tree, velocity_verlet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def kinetic_fn(m_inv, p):\n",
" return 0.5 * np.dot(m_inv, p ** 2)\n",
"\n",
"def potential_fn(q):\n",
" return 0.5 * np.sum(q ** 2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fehiepsi/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/lib/xla_bridge.py:144: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
]
}
],
"source": [
"vv_init, vv_update = velocity_verlet(potential_fn, kinetic_fn)\n",
"step_size = 0.001 # force build full tree ~ 1000 leaves\n",
"rng = random.PRNGKey(0)\n",
"make_test = lambda dim: (vv_init(np.zeros(dim), np.ones(dim)), np.ones(dim))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"%load_ext memory_profiler"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"vv_state, inverse_mass_matrix = make_test(10000)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.18 s, sys: 23.3 ms, total: 4.2 s\n",
"Wall time: 4.19 s\n",
"136 ms ± 1.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"@jit\n",
"def f(vv_state):\n",
" return build_tree(vv_update, kinetic_fn, vv_state,\n",
" inverse_mass_matrix, step_size, rng, iterative_build=True)\n",
"\n",
"%time f(vv_state)\n",
"%timeit f(vv_state)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def g(vv_state):\n",
" return build_tree(vv_update, kinetic_fn, vv_state,\n",
" inverse_mass_matrix, step_size, rng, iterative_build=False)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"vv_state = vv_update(step_size, inverse_mass_matrix, vv_state)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"peak memory: 199.18 MiB, increment: 0.09 MiB\n"
]
}
],
"source": [
"%%memit\n",
"f(vv_state)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"peak memory: 200.36 MiB, increment: 1.18 MiB\n"
]
}
],
"source": [
"%%memit\n",
"g(vv_state)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment