Last active January 1, 2019 23:46
Save akelleh/6ad00740e94029fa25fa953e567642da to your computer and use it in GitHub Desktop.
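The notebook below uses each posterior as the next prior. For this toy model you can check that idea without MCMC. With a Normal prior on beta and known unit noise, the model is conjugate, so each posterior is exactly Normal and its mean and standard deviation summarize it completely. This is a sketch under those assumptions, not the author's code:

- `normal_update` is a hypothetical helper that replaces MCMC with the closed-form conjugate update.
- It uses NumPy's `default_rng` instead of the notebook's global `np.random`.

The sketch feeds in 50 points, then ten batches of 100, carrying each posterior forward. It then compares the result with a single update on all 1,050 points.

```python
import numpy as np


def normal_update(prior_mu, prior_sd, x, y, noise_sd=1.0):
    # Hypothetical helper: exact posterior for beta in y ~ Normal(beta * x, noise_sd)
    # under the prior beta ~ Normal(prior_mu, prior_sd). Returns (mean, sd).
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    precision = 1.0 / prior_sd ** 2 + np.sum(x * x) / noise_sd ** 2
    mean = (prior_mu / prior_sd ** 2 + np.sum(x * y) / noise_sd ** 2) / precision
    return mean, np.sqrt(1.0 / precision)


rng = np.random.default_rng(0)
x = rng.normal(size=1050)
y = rng.normal(3.0 * x)  # same toy model as the notebook: true beta = 3, unit noise

# One update on all 1050 points at once, starting from the N(0, 1) prior.
full_mu, full_sd = normal_update(0.0, 1.0, x, y)

# Online: 50 points first, then ten batches of 100, each posterior becoming the next prior.
mu, sd = normal_update(0.0, 1.0, x[:50], y[:50])
for start in range(50, len(x), 100):
    mu, sd = normal_update(mu, sd, x[start:start + 100], y[start:start + 100])

print(full_mu, mu)
print(full_sd, sd, 1.0 / np.sqrt(1.0 + len(x)))
```

The batch-by-batch and all-at-once estimates agree to floating-point precision. The standard deviation lands near $1/\sqrt{1+N}$, the scaling the notebook checks later (variance like $1/N$, standard deviation like $1/\sqrt{N}$). With MCMC, the carried-forward mean and standard deviation are estimated from draws, so small Monte Carlo errors enter at every step.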
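The notebook's `apply_dynamics` formula is easier to read in terms of precisions (inverse variances). The new precision is $(1-\epsilon)$ times the posterior precision plus $\epsilon$ times the prior precision, and the precision-weighted means are mixed the same way. This sketch does three things:

- Copies `apply_dynamics` from the notebook.
- Adds a hypothetical `blend_precisions` helper that writes the same update in precision form.
- Checks both against the first online step the notebook prints (mean 3.0535, sd 0.0397 before the correction).

It then iterates the precision bookkeeping, assuming each batch of `n = 100` points adds about `n` to the precision. That holds on average here, since $E[x^2] = 1$ and the noise sd is 1.

```python
import numpy as np


def apply_dynamics(posterior_mu, posterior_sd, prior_mu, prior_sd, epsilon=0.1):
    # Same formula as the notebook's apply_dynamics, reformatted.
    adjusted_variance = (posterior_sd * prior_sd) ** 2. / (
        (1.0 - epsilon) * prior_sd ** 2. + epsilon * posterior_sd ** 2.)
    adjusted_mean = adjusted_variance * (
        (1.0 - epsilon) * posterior_mu / posterior_sd ** 2.
        + epsilon * prior_mu / prior_sd ** 2.)
    return adjusted_mean, np.sqrt(adjusted_variance)


def blend_precisions(posterior_mu, posterior_sd, prior_mu, prior_sd, epsilon=0.1):
    # Hypothetical helper: the same update written as a linear blend of
    # precisions (1/sd^2) and precision-weighted means.
    precision = (1.0 - epsilon) / posterior_sd ** 2 + epsilon / prior_sd ** 2
    weighted_mean = ((1.0 - epsilon) * posterior_mu / posterior_sd ** 2
                     + epsilon * prior_mu / prior_sd ** 2)
    return weighted_mean / precision, np.sqrt(1.0 / precision)


# First online step printed in the notebook, before the correction.
post_mu, post_sd = 3.053454970949981, 0.03973987065745886
print(apply_dynamics(post_mu, post_sd, 0.0, 1.0))
print(blend_precisions(post_mu, post_sd, 0.0, 1.0))

# Precision bookkeeping, assuming each batch of n points adds about n
# to the precision (E[x^2] = 1, noise sd = 1) before the correction.
epsilon, n, prior_precision = 0.1, 100, 1.0
precision = prior_precision
for _ in range(200):
    precision = (1.0 - epsilon) * (precision + n) + epsilon * prior_precision
fixed_point = prior_precision + (1.0 - epsilon) * n / epsilon
print(precision, fixed_point)
```

Without the correction, the precision grows without bound (about $1 + N$ after $N$ points). With it, the precision after each correction settles near $1 + (1-\epsilon)n/\epsilon = 901$, so the standard deviation levels off around $1/\sqrt{901} \approx 0.033$ instead of shrinking forever. That floor is the price of forgetting: a batch $k$ steps old keeps only $(1-\epsilon)^k$ of its weight.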
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Let's see if we can do MCMC online, with a dynamics correction. The trick is to summarize the posterior from the previous training run (here, by its mean and standard deviation) and use it as the prior for the next run. Then we only have to train on the new batch of data at each step, so each MCMC run is fast. It's still not fast enough for most real-time applications, but it's pretty quick!\n", | |
| "\n", | |
| "We'll also apply a dynamics correction, which shifts the mean and variance back toward the original prior after each update. That way, if the world changes (e.g., if beta drifts over time), our model can account for it! Inspired by https://github.com/ajtulloch/adpredictor" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pymc3 as pm\n", | |
| "import pandas as pd\n", | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Generate data from a toy model: a 1-D linear model $y = \\beta x + \\epsilon$, with $x \\sim N(0, 1)$, unit-variance Gaussian noise $\\epsilon$, and true $\\beta = 3$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_data():\n", | |
| " x = np.random.normal()\n", | |
| " y = np.random.normal(3. * x)\n", | |
| " return x, y\n", | |
| "\n", | |
| "\n", | |
| "def generate_data(N):\n", | |
| " for _ in range(N):\n", | |
| " yield get_data()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Now, train the model on the first batch of data (50 points) using a pretty bad prior, $\\beta \\sim N(0, 1)$, to kick off the process. We'll save the posterior mean and standard deviation, so we can use them as the prior's hyperparameters at the next time step! That's how we'll carry the learned information forward." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 66, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [sigma, beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 5655.44draws/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x = []\n", | |
| "y = []\n", | |
| "for i in range(50):\n", | |
| " xi, yi = get_data()\n", | |
| " x.append(xi)\n", | |
| " y.append(yi)\n", | |
| " \n", | |
| "parameter_prior_mu = 0\n", | |
| "parameter_prior_sd = 1\n", | |
| "\n", | |
| "with pm.Model() as model:\n", | |
| " beta = pm.Normal('beta', mu=parameter_prior_mu, sd=parameter_prior_sd)\n", | |
| " x = pm.Normal('x', mu=0, sd=1, observed=x)\n", | |
| " y = pm.Normal('y', mu=beta * x, sd=1, observed=y)\n", | |
| " trace = pm.sample()\n", | |
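| " \n", | |
| "# Summarize the posterior by its mean and standard deviation; these become the next prior.\n", | |
| "beta_mu = trace['beta'].mean()\n", | |
| "beta_sd = trace['beta'].std()" | |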
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 69, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[<matplotlib.axes._subplots.AxesSubplot object at 0x7f5ee8117160>,\n", | |
| " <matplotlib.axes._subplots.AxesSubplot object at 0x7f5ee7c1b7b8>]],\n", | |
| " dtype=object)" | |
| ] | |
| }, | |
| "execution_count": 69, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 864x144 with 2 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "pm.traceplot(trace)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 60, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "2.82860702294561" | |
| ] | |
| }, | |
| "execution_count": 60, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "beta_mu" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Now, we want to gather small batches of data (say, 100 points) and update the model with each batch. We don't want to retrain on old data points, and we don't want to train on an ever-growing data set. Since each posterior becomes the next prior, we don't have to!\n", | |
| "\n", | |
| "The dynamics correction pulls the hyperparameters back toward the values of the prior we started with. That adds a kind of \"forgetting\": older batches carry less and less weight, so if the world changes, the model can adapt quickly." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 70, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO (theano.gof.compilelock): Refreshing lock /home/akelleh/.theano/compiledir_Linux-4.15--generic-x86_64-with-Ubuntu-18.04-bionic-x86_64-3.6.6-64/lock_dir/lock\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8880.81draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.053454970949981. stddev is 0.03973987065745886\n", | |
| "after dynamics, mean is 3.052919265939124. stddev is 0.0418858269772828\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7459.47draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.048594287797906. stddev is 0.04015923200119004\n", | |
| "after dynamics, mean is 3.0480480898012186. stddev is 0.04232775509017777\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7901.62draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.047618209780102. stddev is 0.038096180241445156\n", | |
| "after dynamics, mean is 3.0471268372335634. stddev is 0.040153662492752075\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7471.52draws/s]\n", | |
| "Auto-assigning NUTS sampler...\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.049095470409351. stddev is 0.03783759687113559\n", | |
| "after dynamics, mean is 3.048610509732998. stddev is 0.039881157159260834\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7473.20draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8837659016340619, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0279559253022192. stddev is 0.03731755398351496\n", | |
| "after dynamics, mean is 3.0274874721294305. stddev is 0.03933311281097138\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8480.02draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8848219389866341, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.040714538676363. stddev is 0.034905193784378084\n", | |
| "after dynamics, mean is 3.04030295848994. stddev is 0.03679081465373769\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8011.67draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0353504527168753. stddev is 0.035061528490238726\n", | |
| "after dynamics, mean is 3.0349359105578. stddev is 0.036955572295779976\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7970.02draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.045667078474604. stddev is 0.03536037336528721\n", | |
| "after dynamics, mean is 3.04524400746016. stddev is 0.03727051737182386\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7683.32draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.87886307401005, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0409465687452437. stddev is 0.035903428178477276\n", | |
| "after dynamics, mean is 3.0405110810204397. stddev is 0.03784282629622187\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8668.13draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8790482982354095, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0289595408750447. stddev is 0.03502965328282971\n", | |
| "after dynamics, mean is 3.0285466232393787. stddev is 0.03692197975360412\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
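| "# Dynamics correction: blend the posterior back toward the original prior by mixing precisions (1/sd^2)\n", | |
| "# and precision-weighted means as (1 - epsilon) * posterior + epsilon * prior.\n", | |
| "def apply_dynamics(posterior_mu, posterior_sd, prior_mu, prior_sd, epsilon=0.1):\n", | |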
| " adjusted_variance = (posterior_sd * prior_sd) ** 2. / \\\n", | |
| " ((1.0 - epsilon) * prior_sd**2. +\n", | |
| " epsilon * posterior_sd ** 2.)\n", | |
| " adjusted_mean = adjusted_variance * (\n", | |
| " (1.0 - epsilon) * posterior_mu / posterior_sd ** 2. +\n", | |
| " epsilon * prior_mu / prior_sd ** 2.)\n", | |
| "\n", | |
| " return adjusted_mean, np.sqrt(adjusted_variance)\n", | |
| "\n", | |
| "\n", | |
| "batches = 10\n", | |
| "batch_size = 100\n", | |
| "for batch in range(batches):\n", | |
| " x_batch = []\n", | |
| " y_batch = []\n", | |
| " for xi, yi in generate_data(batch_size):\n", | |
| " x_batch.append(xi)\n", | |
| " y_batch.append(yi)\n", | |
| " with pm.Model() as model:\n", | |
| " beta = pm.Normal('beta', mu=beta_mu, sd=beta_sd)\n", | |
| " x = pm.Normal('x', mu=0, sd=1, observed=x_batch)\n", | |
| " y = pm.Normal('y', mu=beta * x, sd=1, observed=y_batch)\n", | |
| " trace = pm.sample()\n", | |
| " \n", | |
| " beta_mu = trace['beta'].mean()\n", | |
| " beta_sd = trace['beta'].std()\n", | |
| "\n", | |
| " print(\"mean is {}. stddev is {}\".format(beta_mu, beta_sd))\n", | |
| " beta_mu, beta_sd = apply_dynamics(beta_mu, beta_sd, parameter_prior_mu, parameter_prior_sd)\n", | |
| " print(\"after dynamics, mean is {}. stddev is {}\".format(beta_mu, beta_sd))\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "How much does the dynamics correction cost us? Without it, the posterior variance should shrink like $1/N$ (so the standard deviation shrinks like $1/\\sqrt{N}$). Let's see how it compares." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 78, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO (theano.gof.compilelock): Refreshing lock /home/akelleh/.theano/compiledir_Linux-4.15--generic-x86_64-with-Ubuntu-18.04-bionic-x86_64-3.6.6-64/lock_dir/lock\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8060.00draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8819365653125937, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7722.36draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8875722379013747, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "The acceptance probability does not match the target. It is 0.8786981772567996, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0978448956341986. stddev is 0.08712304227728766\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7377.37draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.883853341198793, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.073595445041903. stddev is 0.06502291943483667\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7627.26draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0792598915797202. stddev is 0.053708331074624895\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7502.54draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.109806521475409. stddev is 0.04737354009513188\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7513.91draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8925263776307877, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.074426002545947. stddev is 0.04408508675694674\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7643.29draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.072258736919936. stddev is 0.04082190569412854\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7563.38draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.059100197616981. stddev is 0.037263151474141505\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7708.45draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0326486866503157. stddev is 0.034850653454787696\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8003.59draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0030969622879424. stddev is 0.03247572577988825\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8463.38draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8786672009219315, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.011665185873131. stddev is 0.030224481110303854\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "N = 50\n", | |
| "\n", | |
| "x = []\n", | |
| "y = []\n", | |
| "for i in range(N):\n", | |
| " xi, yi = get_data()\n", | |
| " x.append(xi)\n", | |
| " y.append(yi)\n", | |
| " \n", | |
| "parameter_prior_mu = 0\n", | |
| "parameter_prior_sd = 1\n", | |
| "\n", | |
| "with pm.Model() as model:\n", | |
| " beta = pm.Normal('beta', mu=parameter_prior_mu, sd=parameter_prior_sd)\n", | |
| " x = pm.Normal('x', mu=0, sd=1, observed=x)\n", | |
| " y = pm.Normal('y', mu=beta * x, sd=1, observed=y)\n", | |
| " trace = pm.sample()\n", | |
| " \n", | |
| "beta_mu = trace['beta'].mean()\n", | |
| "beta_sd = trace['beta'].std()\n", | |
| "\n", | |
| "\n", | |
| "betas = [beta_mu]\n", | |
| "variances = [beta_sd**2.]\n", | |
| "Ns = [N]\n", | |
| "# With unit noise and E[x^2] = 1, the posterior variance is roughly 1 / (1/prior_sd^2 + N).\n", | |
| "expected_vars = [1. / (1. / parameter_prior_sd**2. + N)]\n", | |
| "\n", | |
| "batches = 10\n", | |
| "batch_size = 100\n", | |
| "for batch in range(batches):\n", | |
| " x_batch = []\n", | |
| " y_batch = []\n", | |
| " for xi, yi in generate_data(batch_size):\n", | |
| " x_batch.append(xi)\n", | |
| " y_batch.append(yi)\n", | |
| " with pm.Model() as model:\n", | |
| " beta = pm.Normal('beta', mu=beta_mu, sd=beta_sd)\n", | |
| " x = pm.Normal('x', mu=0, sd=1, observed=x_batch)\n", | |
| " y = pm.Normal('y', mu=beta * x, sd=1, observed=y_batch)\n", | |
| " trace = pm.sample()\n", | |
| " \n", | |
| " beta_mu = trace['beta'].mean()\n", | |
| " beta_sd = trace['beta'].std()\n", | |
| "\n", | |
| " betas.append(beta_mu)\n", | |
| " variances.append(beta_sd**2.)\n", | |
| " Ns.append(N + (batch + 1) * batch_size)\n", | |
| " expected_vars.append(1. / (1. / parameter_prior_sd**2. + Ns[-1]))\n", | |
| " print(\"mean is {}. stddev is {}\".format(beta_mu, beta_sd))\n", | |
| " \n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 93, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO (theano.gof.compilelock): Refreshing lock /home/akelleh/.theano/compiledir_Linux-4.15--generic-x86_64-with-Ubuntu-18.04-bionic-x86_64-3.6.6-64/lock_dir/lock\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7400.63draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.104672691628748. stddev is 0.2225481433348595\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7450.62draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8888517368030993, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7478.95draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.003328641749641. stddev is 0.15686252520724742\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7503.50draws/s]\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7916.13draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8888287061248509, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "The acceptance probability does not match the target. It is 0.8873024355582213, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.149859057283542. stddev is 0.1321950041134763\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7747.25draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8804837236434276, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8737.66draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.893731802366368, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "The acceptance probability does not match the target. It is 0.896683830655603, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0837114051214187. stddev is 0.11457296455456885\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7907.50draws/s]\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7854.99draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8898292203217597, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.084813065588017. stddev is 0.10123663776510136\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7201.66draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8829888505132729, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7540.10draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8823576288430245, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "The acceptance probability does not match the target. It is 0.878892621323462, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.061054593890492. stddev is 0.09181683943190598\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7791.47draws/s]\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 8152.38draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8842403911583585, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "The acceptance probability does not match the target. It is 0.8797061253564031, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0478423643216828. stddev is 0.08462467428117593\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7229.62draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8814150559687265, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7813.77draws/s]\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0492225191643305. stddev is 0.08142210479745383\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7362.88draws/s]\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7659.84draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8966364277221469, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.0581610994204396. stddev is 0.0756976874336479\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7942.67draws/s]\n", | |
| "The acceptance probability does not match the target. It is 0.8929763677027426, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "The acceptance probability does not match the target. It is 0.8836011362503814, but should be close to 0.8. Try to increase the number of tuning steps.\n", | |
| "Auto-assigning NUTS sampler...\n", | |
| "Initializing NUTS using jitter+adapt_diag...\n", | |
| "Multiprocess sampling (4 chains in 4 jobs)\n", | |
| "NUTS: [beta]\n", | |
| "Sampling 4 chains: 100%|██████████| 4000/4000 [00:00<00:00, 7812.79draws/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mean is 3.008154997317253. stddev is 0.07184543377033048\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "batched_data = []\n", | |
| "batches = 10\n", | |
| "batch_size = 20\n", | |
| "for batch in range(batches):\n", | |
| "    x_batch = []\n", | |
| "    y_batch = []\n", | |
| "    for xi, yi in generate_data(batch_size):\n", | |
| "        x_batch.append(xi)\n", | |
| "        y_batch.append(yi)\n", | |
| "    batched_data.append((x_batch, y_batch))\n", | |
| "\n", | |
| "parameter_prior_mu = 0\n", | |
| "parameter_prior_sd = 1\n", | |
| "\n", | |
| "# Reuse batch_size from above, so Ns records the number of points seen so far.\n", | |
| "for batch, batch_data in enumerate(batched_data):\n", | |
| "    if batch == 0:\n", | |
| "        # First batch: fit from the original prior.\n", | |
| "        x_running, y_running = batch_data\n", | |
| "        with pm.Model() as model:\n", | |
| "            beta = pm.Normal('beta', mu=parameter_prior_mu, sd=parameter_prior_sd)\n", | |
| "            x = pm.Normal('x', mu=0, sd=1, observed=x_running)\n", | |
| "            y = pm.Normal('y', mu=beta * x, sd=1, observed=y_running)\n", | |
| "            trace = pm.sample()\n", | |
| "\n", | |
| "        # Moment-match the posterior; it becomes the prior for the next online update.\n", | |
| "        online_mu = trace['beta'].mean()\n", | |
| "        online_sd = trace['beta'].std()\n", | |
| "\n", | |
| "        betas = [online_mu]\n", | |
| "        mcmc_variances = [online_sd**2.]\n", | |
| "        mcmc_online_variances = [online_sd**2.]\n", | |
| "        Ns = [batch_size]\n", | |
| "    else:\n", | |
| "        x_batch, y_batch = batch_data\n", | |
| "        x_running, y_running = x_running + x_batch, y_running + y_batch\n", | |
| "\n", | |
| "        # Online update: train only on the new batch, using the previous online posterior as the prior.\n", | |
| "        with pm.Model() as model:\n", | |
| "            beta = pm.Normal('beta', mu=online_mu, sd=online_sd)\n", | |
| "            x = pm.Normal('x', mu=0, sd=1, observed=x_batch)\n", | |
| "            y = pm.Normal('y', mu=beta * x, sd=1, observed=y_batch)\n", | |
| "            trace = pm.sample()\n", | |
| "\n", | |
| "        online_mu = trace['beta'].mean()\n", | |
| "        online_sd = trace['beta'].std()\n", | |
| "        betas.append(online_mu)\n", | |
| "        mcmc_online_variances.append(online_sd**2.)\n", | |
| "\n", | |
| "        # Batch reference: refit on all data seen so far, from the original prior.\n", | |
| "        # Separate names keep this fit from overwriting the online prior.\n", | |
| "        with pm.Model() as model:\n", | |
| "            beta = pm.Normal('beta', mu=parameter_prior_mu, sd=parameter_prior_sd)\n", | |
| "            x = pm.Normal('x', mu=0, sd=1, observed=x_running)\n", | |
| "            y = pm.Normal('y', mu=beta * x, sd=1, observed=y_running)\n", | |
| "            trace = pm.sample()\n", | |
| "\n", | |
| "        batch_mu = trace['beta'].mean()\n", | |
| "        batch_sd = trace['beta'].std()\n", | |
| "\n", | |
| "        mcmc_variances.append(batch_sd**2.)\n", | |
| "        Ns.append((batch + 1) * batch_size)\n", | |
| "        print(\"mean is {}. stddev is {}\".format(batch_mu, batch_sd))\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 92, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<matplotlib.axes._subplots.AxesSubplot at 0x7f5ee43ec780>" | |
| ] | |
| }, | |
| "execution_count": 92, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "pd.DataFrame({'$\\sigma^2_{online}$': mcmc_online_variances, \n", | |
| " '$N$': Ns, \n", | |
| " '$\\sigma^2_{batch}$': mcmc_variances}).plot(x='$N$', y=['$\\sigma^2_{online}$', '$\\sigma^2_{batch}$'])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The online posterior variance appears to track the full-batch variance closely. Not bad at all!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
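For this toy model the Normal prior on beta is conjugate to the Gaussian likelihood (noise sd fixed at 1), so the posterior over beta is exactly Normal and has a closed form: the posterior precision is the prior precision plus the sum of x squared, and the posterior mean is the precision-weighted combination of the prior mean and the sum of x times y. The sketch below is an illustrative check that is not part of the original notebook. It uses NumPy only, and the helper name `normal_update` is hypothetical. It shows that feeding each batch's posterior in as the next batch's prior gives the same answer as a single update on all the data. In this particular setting, then, any gap between the online and batch MCMC variances should come from Monte Carlo noise in the sampled mean and standard deviation, not from the online scheme itself.

```python
import numpy as np


def normal_update(prior_mu, prior_sd, x, y, noise_sd=1.0):
    """Hypothetical helper (not from the notebook).

    Closed-form posterior for beta in y ~ Normal(beta * x, noise_sd),
    assuming a Normal(prior_mu, prior_sd) prior and a known noise_sd.
    Returns the posterior (mean, standard deviation).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    prior_prec = 1.0 / prior_sd**2
    # Precisions add: prior precision plus the data's information about beta.
    post_prec = prior_prec + np.sum(x * x) / noise_sd**2
    # Precision-weighted combination of the prior mean and the data.
    post_mu = (prior_prec * prior_mu + np.sum(x * y) / noise_sd**2) / post_prec
    return post_mu, 1.0 / np.sqrt(post_prec)


# Same toy setup as the notebook: true beta = 3, unit noise, 10 batches of 20.
rng = np.random.default_rng(0)
batches, batch_size = 10, 20
xs = rng.normal(size=(batches, batch_size))
ys = rng.normal(3.0 * xs)

# Online: each batch's posterior becomes the next batch's prior.
mu, sd = 0.0, 1.0
for xb, yb in zip(xs, ys):
    mu, sd = normal_update(mu, sd, xb, yb)

# Batch: one update on all the data, starting from the original N(0, 1) prior.
mu_all, sd_all = normal_update(0.0, 1.0, xs.ravel(), ys.ravel())

print("online:", mu, sd)
print("batch: ", mu_all, sd_all)
```

Since the sum of x squared is close to N for standard-normal x, with 200 points the posterior standard deviation is roughly 1/sqrt(1 + 200), or about 0.07, which is on the same scale as the final stddev printed by the MCMC loop. For models without a conjugate prior, summarizing the trace by its mean and standard deviation is only an approximation, and the online and batch results need not agree this closely.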