Last active
April 13, 2021 20:27
-
-
Save twiecki/a77104299535b64b58953de3c84df56f to your computer and use it in GitHub Desktop.
stochastic_volatility.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"colab_type": "text", | |
"id": "view-in-github" | |
}, | |
"cell_type": "markdown", | |
"source": "<a href=\"https://colab.research.google.com/gist/junpenglao/c8b884797f950d1ef033ca69b253a4a0/stochastic_volatility_jax.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
}, | |
{ | |
"metadata": { | |
"id": "QNyI1zUJkAKw" | |
}, | |
"cell_type": "markdown", | |
"source": "# Stochastic Volatility model" | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "lGnXuq-QkAK1", | |
"outputId": "9caf1d60-7859-4f37-e01f-9906969fb7e6", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "import matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nimport pymc3 as pm\nimport pymc3.sampling_jax\nimport arviz as az\n\n%matplotlib inline\n\nnp.random.seed(0)", | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": "/Users/twiecki/projects/pymc/pymc3/sampling_jax.py:24: UserWarning: This module is experimental.\n warnings.warn(\"This module is experimental.\")\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 235 | |
}, | |
"id": "e5QZXr0ckAK3", | |
"outputId": "1cfe592d-850f-4c3b-a8e3-e05dfb26cd9c", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "returns = pd.read_csv(pm.get_data(\"SP500.csv\"), index_col=\"Date\")\nreturns[\"change\"] = np.log(returns[\"Close\"]).diff()\nreturns = returns.dropna()\nreturns.head()", | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>Close</th>\n <th>change</th>\n </tr>\n <tr>\n <th>Date</th>\n <th></th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>2008-05-05</th>\n <td>1407.489990</td>\n <td>-0.004544</td>\n </tr>\n <tr>\n <th>2008-05-06</th>\n <td>1418.260010</td>\n <td>0.007623</td>\n </tr>\n <tr>\n <th>2008-05-07</th>\n <td>1392.569946</td>\n <td>-0.018280</td>\n </tr>\n <tr>\n <th>2008-05-08</th>\n <td>1397.680054</td>\n <td>0.003663</td>\n </tr>\n <tr>\n <th>2008-05-09</th>\n <td>1388.280029</td>\n <td>-0.006748</td>\n </tr>\n </tbody>\n</table>\n</div>", | |
"text/plain": " Close change\nDate \n2008-05-05 1407.489990 -0.004544\n2008-05-06 1418.260010 0.007623\n2008-05-07 1392.569946 -0.018280\n2008-05-08 1397.680054 0.003663\n2008-05-09 1388.280029 -0.006748" | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "gMYC4zx9kAK3", | |
"outputId": "726bcb51-4971-4643-9c56-f264a099ca8f", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "with pm.Model(check_bounds=False) as model:\n step_size = pm.Exponential(\"step_size\", 10)\n volatility = pm.GaussianRandomWalk(\"volatility\", sigma=step_size, \n shape=returns.shape[0], \n init=pm.Normal.dist(0, step_size))\n nu = pm.Exponential(\"nu\", 0.1)\n obs = pm.StudentT(\n \"obs\", nu=nu, sigma=np.exp(volatility), observed=returns[\"change\"]\n )", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": "/Users/twiecki/projects/pymc/pymc3/distributions/continuous.py:138: UserWarning: The variable specified for nu has negative support for StudentT, likely making it unsuitable for this parameter.\n warnings.warn(msg)\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "RVKPflYjkAK4", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "# %%time\n# with model:\n# trace = pm.sampling_jax.sample_numpyro_nuts(2000, tune=2000)", | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "_-RM-GykkAK4", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "# %%time\n# with model:\n# trace = pm.sampling_jax.sample_tfp_nuts(2000, tune=2000)", | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "x_uEuuvTkAK4", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "import os\nimport jax\n\nimport matplotlib\nimport matplotlib.dates as mdates\nimport matplotlib.pyplot as plt\n\nimport jax.numpy as jnp\nimport jax.random as random\n\nimport numpyro\nimport numpyro.distributions as dist\nfrom numpyro.examples.datasets import SP500, load_dataset\nfrom numpyro.infer.hmc import hmc\nfrom numpyro.infer.util import initialize_model\nfrom numpyro.util import fori_collect\nfrom numpyro.infer import MCMC, NUTS\n\n# The matplotlib backend is intentionally not changed here: calling\n# matplotlib.use('Agg') in a notebook overrides the inline backend chosen by\n# `%matplotlib inline`, so later plotting cells would not display their figures.\n\n\ndef model_numpyro(returns):\n    \"\"\"NumPyro version of the stochastic volatility model.\n\n    Sample sites:\n      step_size  ~ Exponential(10)\n      volatility ~ GaussianRandomWalk(scale=step_size), one step per entry of `returns`\n      nu         ~ Exponential(0.1)\n      r          ~ StudentT(df=nu, loc=0, scale=exp(volatility)), observed as `returns`\n\n    Args:\n        returns: array of observed returns; its first dimension sets the walk length.\n\n    Returns:\n        The value of the observed site 'r'.\n    \"\"\"\n    step_size = numpyro.sample('step_size', dist.Exponential(10.))\n    volatility = numpyro.sample(\n        'volatility',\n        dist.GaussianRandomWalk(scale=step_size, num_steps=jnp.shape(returns)[0]),\n    )\n    nu = numpyro.sample('nu', dist.Exponential(.1))\n    return numpyro.sample(\n        'r',\n        dist.StudentT(df=nu, loc=0., scale=jnp.exp(volatility)),\n        obs=returns,\n    )", | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "naiiXZPpkAK5", | |
"outputId": "3de89d6b-daa7-4397-bb25-ebdd1f74cdec", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "init_rng_key, sample_rng_key = random.split(random.PRNGKey(1))\nmodel_info = initialize_model(init_rng_key, model_numpyro, model_args=(returns[\"change\"].values,))\ninit_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')\nhmc_state = init_kernel(model_info.param_info, 2000, rng_key=sample_rng_key)\nhmc_state.z", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "{'nu': DeviceArray(1.85527508, dtype=float64),\n 'step_size': DeviceArray(0.86971885, dtype=float64),\n 'volatility': DeviceArray([-0.05782139, -1.0270261 , 1.03265798, ..., 1.56501771,\n -0.18145358, 1.22302515], dtype=float64)}" | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "iMjkPn08kAK5", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "# Run NUTS.\n\n# num_warmup, num_samples = 1000, 2000\n# kernel = NUTS(model_numpyro)\n# mcmc = MCMC(kernel, num_warmup, num_samples)\n# mcmc.run(sample_rng_key, returns=returns[\"change\"].values)\n# mcmc.print_summary()\n# samples_1 = mcmc.get_samples()", | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5_p-7J2_kAK5", | |
"outputId": "cdcb1006-3658-4d05-ada4-33d1196c96e4", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "fn = jax.jit(model_info.potential_fn)\nprint(fn(hmc_state.z))\n%timeit fn(hmc_state.z).block_until_ready()", | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "8714.197040735067\n594 µs ± 50.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "HYlcxAn6kAK6", | |
"outputId": "e758fa72-b9e2-4f0e-b834-7a1da4ab4aae", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "from theano.link.jax.jax_dispatch import jax_funcify\nfrom theano.compile.mode import FAST_RUN\nimport theano\n\nfgraph = theano.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])\n# Optimize the graph before converting it to JAX. Without this step the JAX function\n# mirrors the raw logp graph (no common-subexpression elimination, fusions, etc.).\n# NOTE: re-run this cell and the ones below; saved outputs predate this optimization step.\n_ = FAST_RUN.optimizer.optimize(fgraph)\nfns = jax_funcify(fgraph)\nlogp_fn_jax = fns[0]\n\n# Evaluate at the model's test point, with inputs in the order of model.free_RVs.\nrv_names = [rv.name for rv in model.free_RVs]\ninit_state = [model.test_point[rv_name] for rv_name in rv_names]\nfn2 = jax.jit(logp_fn_jax)\nfn2(*init_state)  # warm-up call so compilation is excluded from the timing\n%timeit fn2(*init_state).block_until_ready()", | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "671 µs ± 45.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "IHldqMazkep8", | |
"outputId": "d65bf37a-713e-4700-c896-cc13abb07619", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "# Time value + gradient evaluation. The timed statements must call the *_with_grad\n# functions: timing fn / fn2 here would only repeat the value-only benchmarks above.\n# NOTE: re-run this cell; any saved output predates this fix.\nfn_with_grad = jax.jit(jax.value_and_grad(model_info.potential_fn))\nfn_with_grad(hmc_state.z)  # warm-up call so compilation is excluded from the timing\n# value_and_grad returns (value, grad); both come from the same jitted computation,\n# so blocking on the value is enough.\n%timeit fn_with_grad(hmc_state.z)[0].block_until_ready()\n# Differentiate w.r.t. every positional input (the default argnums=0 covers only the\n# first one), matching the NumPyro potential whose gradient covers the whole parameter dict.\nfn2_with_grad = jax.jit(jax.value_and_grad(logp_fn_jax, argnums=tuple(range(len(init_state)))))\nfn2_with_grad(*init_state)  # warm-up call\n%timeit fn2_with_grad(*init_state)[0].block_until_ready()", | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "667 µs ± 90.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n678 µs ± 57.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "TOBIX1HNkAK6", | |
"outputId": "a2672950-f5c8-4716-8e59-cd0c9d9d1053", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "init_state2 = [hmc_state.z['step_size'], hmc_state.z['volatility'], hmc_state.z['nu']]\nfn(hmc_state.z), fn2(*init_state2)", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(DeviceArray(8714.19704074, dtype=float64),\n DeviceArray(-8714.19704074, dtype=float64))" | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "189J9pS6kAK6", | |
"outputId": "57cca92b-35de-49e9-d5b6-53d93d558e4b", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "model.test_point", | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "{'step_size_log__': array(-2.66909801),\n 'volatility': array([0., 0., 0., ..., 0., 0., 0.]),\n 'nu_log__': array(1.93607218)}" | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "w_ZlLLOmkAK6", | |
"outputId": "53587241-617d-4e0c-a5ed-ba14fc5b3208", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "z2 = {\n 'step_size': model.test_point['step_size_log__'],\n 'nu': model.test_point['nu_log__'],\n 'volatility': model.test_point['volatility']\n}\nfn(z2), fn2(*init_state)", | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(DeviceArray(-2307.90263155, dtype=float64),\n DeviceArray(2307.90263155, dtype=float64))" | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "9Pcu4Q2EkAK7", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "draws=2000\ntune=2000\nchains=4\ntarget_accept=0.8\nrandom_seed=10\nprogress_bar=True\nseed = jax.random.PRNGKey(random_seed)", | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "CIyw3GLBkAK7", | |
"outputId": "c0307a4a-7d18-4ece-f83e-6a859026dab8", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "%%time\n# Give every chain the same starting point by stacking it along a new leading axis.\ninit_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), hmc_state.z)\n\n\n@jax.jit\ndef _sample(current_state, seed):\n    \"\"\"Run NumPyro NUTS on the NumPyro model's potential function.\n\n    Args:\n        current_state: initial parameters, batched over chains on the leading axis.\n        seed: PRNG key(s) passed to `MCMC.run`.\n\n    Returns:\n        (samples, leapfrogs_taken), both grouped by chain.\n    \"\"\"\n    nuts_kernel = NUTS(\n        potential_fn=model_info.potential_fn,\n        target_accept_prob=target_accept,\n        adapt_step_size=True,\n        adapt_mass_matrix=True,\n        dense_mass=False,\n    )\n\n    pmap_numpyro = MCMC(\n        nuts_kernel,\n        num_warmup=tune,\n        num_samples=draws,\n        num_chains=chains,\n        postprocess_fn=None,\n        chain_method=\"parallel\",\n        progress_bar=progress_bar,\n    )\n\n    pmap_numpyro.run(seed, init_params=current_state, extra_fields=(\"num_steps\",))\n    samples = pmap_numpyro.get_samples(group_by_chain=True)\n    leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)[\"num_steps\"]\n    return samples, leapfrogs_taken\n\n\nprint(\"Compiling...\")\ntic2 = pd.Timestamp.now()\nmap_seed = jax.random.split(seed, chains)  # one PRNG key per chain\nposterior, leapfrogs_taken = _sample(init_state_batched, map_seed)\nleapfrogs_taken.block_until_ready()  # wait for the async computation before stopping the clock\n\ntic3 = pd.Timestamp.now()\nprint(\"Compilation + sampling time = \", tic3 - tic2)", | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "Compiling...\nCompilation + sampling time = 0 days 00:05:27.456873\nCPU times: user 11min, sys: 16.8 s, total: 11min 17s\nWall time: 5min 27s\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "z7u27qErPQ-J", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "az_trace = az.from_dict(posterior=posterior)", | |
"execution_count": 21, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Df2t46y9LbcU", | |
"outputId": "e18b968e-be56-4341-9e26-9b486af96e19", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "%%time\n# Give every chain the same starting point by stacking it along a new leading axis.\ninit_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), hmc_state.z)\n# logp_fn_jax takes positional inputs, so pass the batched values as a list.\ninit_state_batched_ = [init_state_batched['step_size'], init_state_batched['volatility'], init_state_batched['nu']]\n\n\n@jax.jit\ndef _sample(current_state, seed):\n    \"\"\"Run NumPyro NUTS on the negated PyMC3 log-probability function.\n\n    Args:\n        current_state: list of initial parameter values, batched over chains on the leading axis.\n        seed: PRNG key(s) passed to `MCMC.run`.\n\n    Returns:\n        (samples, leapfrogs_taken), both grouped by chain.\n    \"\"\"\n    nuts_kernel = NUTS(\n        # NUTS expects a potential (negative log-density), hence the sign flip.\n        potential_fn=lambda x: -logp_fn_jax(*x),\n        target_accept_prob=target_accept,\n        adapt_step_size=True,\n        adapt_mass_matrix=True,\n        dense_mass=False,\n    )\n\n    pmap_numpyro = MCMC(\n        nuts_kernel,\n        num_warmup=tune,\n        num_samples=draws,\n        num_chains=chains,\n        postprocess_fn=None,\n        chain_method=\"parallel\",\n        progress_bar=progress_bar,\n    )\n\n    pmap_numpyro.run(seed, init_params=current_state, extra_fields=(\"num_steps\",))\n    samples = pmap_numpyro.get_samples(group_by_chain=True)\n    leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)[\"num_steps\"]\n    return samples, leapfrogs_taken\n\n\nprint(\"Compiling...\")\ntic2 = pd.Timestamp.now()\n# Derive the per-chain keys here instead of relying on a variable left over from another cell.\nmap_seed = jax.random.split(seed, chains)\nposterior_pymc3, leapfrogs_taken_pymc3 = _sample(init_state_batched_, map_seed)\nleapfrogs_taken_pymc3.block_until_ready()  # wait for the async computation before stopping the clock\n\ntic3 = pd.Timestamp.now()\nprint(\"Compilation + sampling time = \", tic3 - tic2)", | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "Compiling...\nCompilation + sampling time = 0 days 00:06:11.912123\nCPU times: user 14min 45s, sys: 17.2 s, total: 15min 2s\nWall time: 6min 12s\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "QVNVrt7rPaZK", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "az_trace_pymc3 = az.from_dict(posterior={k:v for k, v in zip(['step_size','volatility','nu'], posterior_pymc3)})", | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "XGrthQIXM-3j", | |
"outputId": "f4b00a42-1c1b-4d0c-ff6c-b86afe7b0832", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "leapfrogs_taken_pymc3.mean(), leapfrogs_taken.mean()", | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(DeviceArray(173.944, dtype=float64), DeviceArray(140.28, dtype=float64))" | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "leapfrogs_taken_pymc3", | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "DeviceArray([[127, 127, 127, ..., 127, 127, 127],\n [127, 127, 127, ..., 127, 127, 127],\n [127, 255, 255, ..., 255, 127, 255],\n [127, 255, 127, ..., 255, 127, 127]], dtype=int64)" | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "leapfrogs_taken", | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "DeviceArray([[127, 127, 127, ..., 127, 127, 127],\n [127, 127, 127, ..., 127, 127, 127],\n [255, 255, 127, ..., 127, 127, 127],\n [127, 127, 127, ..., 127, 127, 127]], dtype=int64)" | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "1ByuHSGZkAK7", | |
"outputId": "050ed585-f1ac-41ee-b7df-1f707e675240", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "az.plot_trace(az_trace_pymc3, var_names=[\"step_size\", \"nu\"])", | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "array([[<AxesSubplot:title={'center':'step_size'}>,\n <AxesSubplot:title={'center':'step_size'}>],\n [<AxesSubplot:title={'center':'nu'}>,\n <AxesSubplot:title={'center':'nu'}>]], dtype=object)" | |
}, | |
"execution_count": 25, | |
"metadata": { | |
"tags": [] | |
}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "cdc9DMLRkAK8", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "az.plot_trace(az_trace, var_names=[\"step_size\", \"nu\"])", | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "Xk6FvW3LkAK8", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "from jax import make_jaxpr", | |
"execution_count": 24, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "u8uXffVEkAK8", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "print(make_jaxpr(logp_fn_jax)(*init_state))", | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "{ lambda o bh bk bn bq cc cv cw cy dr dt dw ; a b c.\n let d = exp a\n e = mul 10.0 d\n f = sub 2.302585092994046 e\n g = add f a\n h = reduce_sum[ axes=() ] g\n i = reduce_sum[ axes=() ] h\n j = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] i\n k = exp a\n l = pow k -2.0\n m = mul 1.0 l\n n = neg m\n p = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))\n slice_sizes=(1,) ] b o\n q = sub p 0.0\n r = pow q 2.0\n s = mul n r\n t = exp a\n u = pow t -2.0\n v = mul 1.0 u\n w = div v 3.141592653589793\n x = div w 2.0\n y = log x\n z = add s y\n ba = div z 2.0\n bb = exp a\n bc = mul 1.0 bb\n bd = pow bc -2.0\n be = mul 1.0 bd\n bf = neg be\n bg = reshape[ dimensions=None\n new_sizes=(1,) ] bf\n bi = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] b bh\n bj = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] bi\n bl = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] b bk\n bm = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] bl\n bo = add bm bn\n bp = sub bj bo\n br = pow bp bq\n bs = mul bg br\n bt = exp a\n bu = mul 1.0 bt\n bv = pow bu -2.0\n bw = mul 1.0 bv\n bx = div bw 3.141592653589793\n by = div bx 2.0\n bz = log by\n ca = reshape[ dimensions=None\n new_sizes=(1,) ] bz\n cb = add bs ca\n cd = div cb cc\n ce = reduce_sum[ axes=(0,) ] cd\n cf = add ba ce\n cg = reduce_sum[ axes=() ] cf\n ch = reduce_sum[ axes=() ] cg\n ci = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] ch\n cj = exp c\n ck = mul 0.1 cj\n cl = sub -2.3025850929940455 ck\n cm = add cl c\n cn = reduce_sum[ axes=() ] cm\n co = reduce_sum[ axes=() ] cn\n cp = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] co\n cq = exp c\n cr = add cq 1.0\n cs = div cr 2.0\n ct = lgamma cs\n cu = reshape[ 
dimensions=None\n new_sizes=(1,) ] ct\n cx = exp b\n cz = pow cx cy\n da = mul cw cz\n db = exp c\n dc = mul db 3.141592653589793\n dd = reshape[ dimensions=None\n new_sizes=(1,) ] dc\n de = div da dd\n df = log de\n dg = mul cv df\n dh = add cu dg\n di = exp c\n dj = div di 2.0\n dk = lgamma dj\n dl = reshape[ dimensions=None\n new_sizes=(1,) ] dk\n dm = sub dh dl\n dn = exp c\n do = add dn 1.0\n dp = div do 2.0\n dq = reshape[ dimensions=None\n new_sizes=(1,) ] dp\n ds = exp b\n du = pow ds dt\n dv = mul dr du\n dx = mul dv dw\n dy = exp c\n dz = reshape[ dimensions=None\n new_sizes=(1,) ] dy\n ea = div dx dz\n eb = log1p ea\n ec = mul dq eb\n ed = sub dm ec\n ee = reduce_sum[ axes=(0,) ] ed\n ef = reduce_sum[ axes=() ] ee\n eg = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] ef\n eh = concatenate[ dimension=0 ] j ci cp eg\n ei = reduce_sum[ axes=(0,) ] eh\n in (ei,) }\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "3sCUA5pIkAK9", | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "print(make_jaxpr(model_info.potential_fn)(z2))", | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "{ lambda m v y bz ; a b c.\n let d = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] b\n e = reduce_sum[ axes=(0,) ] d\n f = reduce_sum[ axes=() ] e\n g = add 0.0 f\n h = exp b\n i = mul h 10.0\n j = sub 2.302585092994046 i\n k = reduce_sum[ axes=() ] j\n l = add g k\n n = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))\n slice_sizes=(1,) ] c m\n o = sub n 0.0\n p = div o h\n q = integer_pow[ y=2 ] p\n r = mul q -0.5\n s = mul 2.5066282746310002 h\n t = log s\n u = sub r t\n w = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] c v\n x = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] w\n z = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))\n slice_sizes=(2904,) ] c y\n ba = broadcast_in_dim[ broadcast_dimensions=(0,)\n shape=(2904,) ] z\n bb = sub x ba\n bc = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] h\n bd = div bb bc\n be = integer_pow[ y=2 ] bd\n bf = mul be -0.5\n bg = mul 2.5066282746310002 bc\n bh = log bg\n bi = sub bf bh\n bj = reduce_sum[ axes=(0,) ] bi\n bk = add u bj\n bl = reduce_sum[ axes=() ] bk\n bm = add l bl\n bn = broadcast_in_dim[ broadcast_dimensions=()\n shape=(1,) ] a\n bo = reduce_sum[ axes=(0,) ] bn\n bp = reduce_sum[ axes=() ] bo\n bq = add bm bp\n br = exp a\n bs = mul br 0.1\n bt = sub -2.3025850929940455 bs\n bu = reduce_sum[ axes=() ] bt\n bv = add bq bu\n bw = broadcast_in_dim[ broadcast_dimensions=()\n shape=(2905,) ] br\n bx = add bw 1.0\n by = mul bx -0.5\n ca = exp c\n cb = div bz ca\n cc = pow cb 2.0\n cd = div cc bw\n ce = log1p cd\n cf = mul by ce\n cg = log ca\n ch = log bw\n ci = mul ch 0.5\n cj = add cg ci\n ck = add cj 0.5723649429247001\n cl = mul bw 0.5\n cm = lgamma cl\n cn = add ck cm\n co = add bw 1.0\n cp = mul co 0.5\n cq = lgamma cp\n cr = sub cn cq\n cs 
= sub cf cr\n ct = reduce_sum[ axes=(0,) ] cs\n cu = add bv ct\n cv = neg cu\n in (cv,) }\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "i0qhlJo6kAK9" | |
}, | |
"cell_type": "markdown", | |
"source": "## References\n\n1. Hoffman & Gelman. (2011). [The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo](http://arxiv.org/abs/1111.4246). " | |
} | |
], | |
"metadata": { | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/a77104299535b64b58953de3c84df56f" | |
}, | |
"anaconda-cloud": {}, | |
"colab": { | |
"collapsed_sections": [], | |
"include_colab_link": true, | |
"name": "c8b884797f950d1ef033ca69b253a4a0#file-stochastic_volatility_jax-ipynb", | |
"provenance": [] | |
}, | |
"gist": { | |
"id": "a77104299535b64b58953de3c84df56f", | |
"data": { | |
"description": "stochastic_volatility.ipynb", | |
"public": true | |
} | |
}, | |
"kernelspec": { | |
"name": "pymc3theano", | |
"display_name": "pymc3theano", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.8.5", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
How do you mean?
…On Mon, Apr 12, 2021, 21:01 Brandon T. Willard ***@***.***> wrote:
***@***.**** commented on this gist.
------------------------------
I just noticed that this example isn't optimizing the FunctionGraph.
—
You are receiving this because you authored the thread.
Reply to this email directly, view it on GitHub
<https://gist.github.com/a77104299535b64b58953de3c84df56f#gistcomment-3703154>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAFETGGVJP47QKPRU54KYL3TIM7PFANCNFSM42Z3R32A>
.
Doing something like the following will optimize the `FunctionGraph` in roughly the same way that `aesara.function` does:
from aesara.compile.mode import FAST_RUN
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
_ = FAST_RUN.optimizer.optimize(fgraph)
Without that step, the JAX function will take the exact form of the log-likelihood graph determined by the `Distribution.logp` implementations (i.e. no CSE, fusions, in-place operations, etc.).
I suppose `pm.sample()` already does this?
This looks like something we need to update in PyMC3, as well.
Here's a quick comparison of the timing with and without graph optimizations (the example/model is taken from this notebook):
fgraph = model.logp.f.maker.fgraph
...
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 198 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
...
get_ipython().run_line_magic('timeit', 'fn2(*init_state).block_until_ready()')
# 236 µs ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I just noticed that this example isn't optimizing the `FunctionGraph`.