{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Context\n",
"\n",
"This notebook is a follow up to an excelent [notebook](https://colab.research.google.com/drive/1kRW-X2z5GBsNFtN0dDvPo21yGT0ASosJ?usp=sharing) by [ckrapu](https://discourse.pymc.io/u/ckrapu), originally shared in this [Discourse issue](https://discourse.pymc.io/t/hierarchical-changepoint-detection/10789)\n",
"\n",
"It attempts to sample the first iteration model, marginalizing over the number of changepoints. This relies on a not-yet merged PR in pymc-experimental: https://github.com/pymc-devs/pymcx/pull/91"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D8VNQICRbIEK"
},
"source": [
"# Changepoint models\n",
"\n",
"Identifying structural breaks in data is an important problem to folks that frequently work with time series data. Some examples of how people have dealt with the problem of a single changepoint can be found [here](https://cscherrer.github.io/post/bayesian-changepoint/) and [here](https://mc-stan.org/docs/2_23/stan-users-guide/change-point-section.html). \n",
"\n",
"Generally, the setup looks like this: we have some data $X_t$ indexed by a discrete time coordinate $t \\in \\{1,...,T\\}$ and a parametric submodel linking the distribution of $X$ to another quantity $\\mu_t$ which depends on the temporal coordinate and which render $X_1,...,X_T$ conditionally independent given the quantity $\\mu_t$. For the case of a linear Gaussian model with a single change point, we have\n",
"\n",
"\n",
"$$a_1, a_2 \\sim N(0, \\sigma^2_\\mu)$$\n",
"\n",
"$$\\tau \\sim \\text{DiscreteUniform}(\\{1,...,T\\})$$\n",
"\n",
"$$\\mu_t = \\left\\{\n",
" \\begin{array}{l}\n",
" a_1 \\text{ if } t < \\tau \\\\\n",
" a_2 \\text{ if } t \\ge \\tau\n",
" \\end{array}\n",
" \\right. $$\n",
"\n",
"$$X_t \\sim N(\\mu_t, \\sigma_\\epsilon)$$\n",
"\n",
"with your scale priors of choice on the variance parameters $\\sigma_\\epsilon$ and $\\sigma_\\mu$. Now, one of the main conceptual problems with this model is that you need to assume it has a single changepoint. You can relax that assumption by extending this model to include more $\\tau$ and $a$ parameters, but you'll still need to specify the number of them ahead of time. \n",
"\n",
"Relaxing the assumption on the number of parameters is, for the most part, a solved problem in the research community (see [here](https://www.sciencedirect.com/science/article/abs/pii/S0167715297000503) and [here](https://repository.upenn.edu/cgi/viewcontent.cgi?article=1376&context=statistics_papers) for a few representative examples). Unfortunately, these require the analyst to implement the inference techniques presented by hand; these are often Gibbs samplers or similar. Wouldn't it be nice to just be able to use a PPL and write down the forward process instead?\n",
"\n",
"That's the point of this notebook - we'll walk through a construction of a changepoint model plus inference in PyMC which is considerably more straightforward than a handwritten sampler. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XuRuGxmcfbDX"
},
"source": [
"We'll start by simulating some data over 50 timesteps; there are 4 changepoints \n",
"and the model's likelihood is Gaussian. We will use a standard set of imports for working with PyMC and set the seed for repeatability."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "xoHV2lhYbIf8"
},
"outputs": [],
"source": [
"import arviz as az\n",
"import pymc as pm\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import scipy.special as sp\n",
"import aesara.tensor as at\n",
"from collections import Counter\n",
"from IPython.display import set_matplotlib_formats\n",
"import xarray as xr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MAMECjBVU6lG"
},
"source": [
"# Simulating a dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gEdVVufPiP6x"
},
"source": [
"Since the generative process for this data is simple, the code required to simulate data is relatively short. We begin by sampling the changepoints and then adding offsets for each changepoint to the mean value of the data. We then perturb this mean with normal noise variates to create simulated observations."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "f3NdjpDy_Ifb"
},
"outputs": [],
"source": [
"np.random.seed(827)\n",
"rng = np.random.default_rng(827)\n",
"\n",
"T = 50\n",
"noise_sd_true = 0.15\n",
"n_cps_true = 4\n",
"\n",
"def simulate_data(T, n_changepoints, noise_sd=0.15):\n",
" cp_times = np.sort(np.random.choice(T, size=n_cps_true))\n",
" cp_deltas = np.random.randn(n_cps_true)\n",
"\n",
" noiseless = np.zeros(T)\n",
" start_time = 0\n",
"\n",
" for cp_time, cp_delta in zip(cp_times, cp_deltas):\n",
" noiseless[start_time:cp_time] += cp_delta\n",
" start_time = cp_time\n",
"\n",
" xs = noiseless + np.random.randn(T) * noise_sd_true\n",
" return xs, noiseless, cp_times, cp_deltas \n",
"\n",
"xs, noiseless, cp_times_true, cp_deltas_true = simulate_data(T, n_cps_true)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 9, 22, 37, 45])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cp_times_true"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.26203003, 0.47949834, -0.37238606, 0.09406072])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cp_deltas_true"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s_zDGssQihSd"
},
"source": [
"As we can see below, the green changepoints do clearly correspond to changes in the level of the time series. However, not all of them are obvious - the last one, in particular, is a relatively small jump."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 301
},
"id": "kTssO46aiN3T",
"outputId": "83cd9247-6a72-4205-dd87-9b7eab93eff1"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 900x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(9,3))\n",
"plt.plot(xs, marker='o', color='k', label='Observed data')\n",
"plt.plot(noiseless, color='g', label='Noise-free mean value')\n",
"\n",
"for i, cp in enumerate(cp_times_true):\n",
" if i == 0:\n",
" label = 'Change point'\n",
" else:\n",
" label=None\n",
" plt.axvline(cp, color='g', linestyle='--',label=label)\n",
"\n",
"plt.legend()\n",
"\n",
"plt.xlabel('Timestep',fontsize=12)\n",
"plt.ylabel('$X(t)$',fontsize=18);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model with changepoints marginalized"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.DiscreteUniform.dist(0, 2, size=2000).eval().max()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.choice(2, size=2000).max()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will try to help NUTS by marginalizing over the number of changepoints, using experimental featuer from pymc_experimental"
]
},
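{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concretely, marginalizing means summing the discrete variable out of the joint density, so that the sampler never has to propose values for it. A sketch of the idea, assuming the uniform prior on `n_cps` over $\\{0, \\dots, K\\}$ used in the model below, where $K$ stands for `max_cp_inference` and $\\theta$ is shorthand (introduced here) for all remaining parameters:\n",
"\n",
"$$p(x \\mid \\theta) = \\sum_{k=0}^{K} p(n_{cps}=k) \\, p(x \\mid n_{cps}=k, \\theta) = \\frac{1}{K+1} \\sum_{k=0}^{K} \\prod_{t} p(x_t \\mid \\mu_t(k, \\theta), \\sigma_\\epsilon)$$\n",
"\n",
"With $K = 10$ this is a sum of 11 terms per likelihood evaluation. I am assuming that `MarginalModel` evaluates it by enumerating the support of `n_cps`, which is why a small upper bound is convenient."
]
},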
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from pymc_experimental.marginal_model import MarginalModel"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"max_cp_inference = 10\n",
"tiled_times = np.arange(T)[:, None].repeat(max_cp_inference, axis=1)\n",
"\n",
"with MarginalModel() as m:\n",
" # Not implemented yet, but otherwise a nice way to specify more informed prior\n",
"# n_cps = pm.Truncated(\"n_cps\", pm.Poisson.dist(5), 0, T)\n",
" n_cps = pm.DiscreteUniform(\"n_cps\", 0, max_cp_inference)\n",
" \n",
" # We can't marginalize `cp_times` because each dim contributes non-independently to the likelihood\n",
" # It doesn't make sense to consider cp_times of 0 or T, as we can't observe changes\n",
" cp_times = pm.DiscreteUniform(\"cp_times\", 1, T-1, shape=max_cp_inference)\n",
"\n",
" cp_sd = pm.HalfNormal('cp_sd', sigma=2)\n",
" cp_deltas = pm.Normal('cp_deltas', cp_sd, shape=max_cp_inference)\n",
" \n",
" global_mean = pm.Normal('global_mean', sigma=1)\n",
" noise_sd = pm.HalfNormal('noise_sd', sigma=1)\n",
" \n",
" cp_times_sorted = cp_times.sort()\n",
" is_timestep_past_cp = (tiled_times >= cp_times_sorted[None, :].repeat(T, axis=0))\n",
" cp_contrib = at.sum(cp_deltas[:n_cps] * is_timestep_past_cp[:, :n_cps], axis=1)\n",
" \n",
" mu = global_mean + cp_contrib\n",
" pm.Normal('likelihood', mu=mu, sigma=noise_sd, observed=xs)"
]
},
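{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written out, the mean built in the cell above is, as far as I can tell from the code, a sum of step functions. The symbols are shorthand introduced here for illustration: $m$ is `global_mean`, $\\delta_i$ is the $i$-th entry of `cp_deltas`, and $\\tau_{(1)} \\le \\dots \\le \\tau_{(K)}$ are the sorted `cp_times`, with $K$ standing for `max_cp_inference`:\n",
"\n",
"$$\\mu_t = m + \\sum_{i=1}^{n_{cps}} \\delta_i \\, \\mathbb{1}[t \\ge \\tau_{(i)}], \\qquad t = 0, \\dots, T-1$$\n",
"\n",
"Only the $n_{cps}$ earliest candidate times are active, so $n_{cps} = 0$ gives a flat mean $\\mu_t = m$, and each additional changepoint adds one more step."
]
},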
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"m.marginalize([n_cps])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 6.0.1 (20220911.2005)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"680pt\" height=\"355pt\"\n",
" viewBox=\"0.00 0.00 679.71 354.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 350.86)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-350.86 675.71,-350.86 675.71,4 -4,4\"/>\n",
"<g id=\"clust1\" class=\"cluster\">\n",
"<title>cluster10</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M331.71,-129.95C331.71,-129.95 651.71,-129.95 651.71,-129.95 657.71,-129.95 663.71,-135.95 663.71,-141.95 663.71,-141.95 663.71,-231.91 663.71,-231.91 663.71,-237.91 657.71,-243.91 651.71,-243.91 651.71,-243.91 331.71,-243.91 331.71,-243.91 325.71,-243.91 319.71,-237.91 319.71,-231.91 319.71,-231.91 319.71,-141.95 319.71,-141.95 319.71,-135.95 325.71,-129.95 331.71,-129.95\"/>\n",
"<text text-anchor=\"middle\" x=\"646.21\" y=\"-137.75\" font-family=\"Times,serif\" font-size=\"14.00\">10</text>\n",
"</g>\n",
"<g id=\"clust2\" class=\"cluster\">\n",
"<title>cluster50</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M254.71,-8C254.71,-8 368.71,-8 368.71,-8 374.71,-8 380.71,-14 380.71,-20 380.71,-20 380.71,-109.95 380.71,-109.95 380.71,-115.95 374.71,-121.95 368.71,-121.95 368.71,-121.95 254.71,-121.95 254.71,-121.95 248.71,-121.95 242.71,-115.95 242.71,-109.95 242.71,-109.95 242.71,-20 242.71,-20 242.71,-14 248.71,-8 254.71,-8\"/>\n",
"<text text-anchor=\"middle\" x=\"363.21\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">50</text>\n",
"</g>\n",
"<!-- cp_deltas -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>cp_deltas</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"596.71\" cy=\"-198.43\" rx=\"58.88\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"596.71\" y=\"-209.73\" font-family=\"Times,serif\" font-size=\"14.00\">cp_deltas</text>\n",
"<text text-anchor=\"middle\" x=\"596.71\" y=\"-194.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"596.71\" y=\"-179.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- likelihood -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>likelihood</title>\n",
"<ellipse fill=\"lightgrey\" stroke=\"black\" cx=\"311.71\" cy=\"-76.48\" rx=\"60.62\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"311.71\" y=\"-87.78\" font-family=\"Times,serif\" font-size=\"14.00\">likelihood</text>\n",
"<text text-anchor=\"middle\" x=\"311.71\" y=\"-72.78\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"311.71\" y=\"-57.78\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- cp_deltas&#45;&gt;likelihood -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>cp_deltas&#45;&gt;likelihood</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M570.53,-164.63C558.94,-152.09 544.39,-138.71 528.71,-129.95 483.5,-104.7 426.38,-91.51 382.1,-84.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"382.4,-81.18 371.99,-83.19 381.38,-88.11 382.4,-81.18\"/>\n",
"</g>\n",
"<!-- cp_times -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>cp_times</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"423.71\" cy=\"-198.43\" rx=\"96.33\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"423.71\" y=\"-209.73\" font-family=\"Times,serif\" font-size=\"14.00\">cp_times</text>\n",
"<text text-anchor=\"middle\" x=\"423.71\" y=\"-194.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"423.71\" y=\"-179.73\" font-family=\"Times,serif\" font-size=\"14.00\">DiscreteUniform</text>\n",
"</g>\n",
"<!-- cp_times&#45;&gt;likelihood -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>cp_times&#45;&gt;likelihood</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M391.62,-163.06C378.18,-148.67 362.5,-131.87 348.59,-116.98\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"350.89,-114.31 341.51,-109.39 345.77,-119.09 350.89,-114.31\"/>\n",
"</g>\n",
"<!-- cp_sd -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>cp_sd</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"596.71\" cy=\"-309.38\" rx=\"70.92\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"596.71\" y=\"-320.68\" font-family=\"Times,serif\" font-size=\"14.00\">cp_sd</text>\n",
"<text text-anchor=\"middle\" x=\"596.71\" y=\"-305.68\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"596.71\" y=\"-290.68\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- cp_sd&#45;&gt;cp_deltas -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>cp_sd&#45;&gt;cp_deltas</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M596.71,-271.8C596.71,-263.63 596.71,-254.85 596.71,-246.32\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"600.21,-246.1 596.71,-236.1 593.21,-246.1 600.21,-246.1\"/>\n",
"</g>\n",
"<!-- noise_sd -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>noise_sd</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"70.71\" cy=\"-198.43\" rx=\"70.92\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-209.73\" font-family=\"Times,serif\" font-size=\"14.00\">noise_sd</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-194.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"70.71\" y=\"-179.73\" font-family=\"Times,serif\" font-size=\"14.00\">HalfNormal</text>\n",
"</g>\n",
"<!-- noise_sd&#45;&gt;likelihood -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>noise_sd&#45;&gt;likelihood</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M103.2,-165.07C116.96,-152.75 133.72,-139.41 150.71,-129.95 179.91,-113.7 214.91,-101.54 244.81,-93.04\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"246.04,-96.33 254.74,-90.29 244.17,-89.58 246.04,-96.33\"/>\n",
"</g>\n",
"<!-- global_mean -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>global_mean</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"234.71\" cy=\"-198.43\" rx=\"74.91\" ry=\"37.45\"/>\n",
"<text text-anchor=\"middle\" x=\"234.71\" y=\"-209.73\" font-family=\"Times,serif\" font-size=\"14.00\">global_mean</text>\n",
"<text text-anchor=\"middle\" x=\"234.71\" y=\"-194.73\" font-family=\"Times,serif\" font-size=\"14.00\">~</text>\n",
"<text text-anchor=\"middle\" x=\"234.71\" y=\"-179.73\" font-family=\"Times,serif\" font-size=\"14.00\">Normal</text>\n",
"</g>\n",
"<!-- global_mean&#45;&gt;likelihood -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>global_mean&#45;&gt;likelihood</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M257.19,-162.41C265.72,-149.13 275.52,-133.85 284.45,-119.95\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"287.41,-121.81 289.87,-111.5 281.52,-118.03 287.41,-121.81\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7fdb7e426d70>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pm.model_to_graphviz(m)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Multiprocess sampling (4 chains in 4 jobs)\n",
"CompoundStep\n",
">Metropolis: [cp_times]\n",
">NUTS: [cp_sd, cp_deltas, global_mean, noise_sd]\n"
]
},
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <progress value='8000' class='' max='8000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [8000/8000 06:39&lt;00:00 Sampling 4 chains, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling 4 chains for 1_500 tune and 500 draw iterations (6_000 + 2_000 draws total) took 400 seconds.\n"
]
}
],
"source": [
"with m:\n",
" trace = pm.sample(tune=1500, chains=4, draws=500, random_seed=rng)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"No divergences!!!"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>sd</th>\n",
" <th>hdi_3%</th>\n",
" <th>hdi_97%</th>\n",
" <th>mcse_mean</th>\n",
" <th>mcse_sd</th>\n",
" <th>ess_bulk</th>\n",
" <th>ess_tail</th>\n",
" <th>r_hat</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>cp_times[0]</th>\n",
" <td>24.256</td>\n",
" <td>17.229</td>\n",
" <td>1.000</td>\n",
" <td>47.000</td>\n",
" <td>8.301</td>\n",
" <td>6.320</td>\n",
" <td>5.0</td>\n",
" <td>11.0</td>\n",
" <td>2.04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[1]</th>\n",
" <td>40.992</td>\n",
" <td>4.821</td>\n",
" <td>36.000</td>\n",
" <td>49.000</td>\n",
" <td>0.547</td>\n",
" <td>0.388</td>\n",
" <td>73.0</td>\n",
" <td>68.0</td>\n",
" <td>1.04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[2]</th>\n",
" <td>22.595</td>\n",
" <td>12.367</td>\n",
" <td>9.000</td>\n",
" <td>48.000</td>\n",
" <td>5.797</td>\n",
" <td>4.407</td>\n",
" <td>6.0</td>\n",
" <td>6.0</td>\n",
" <td>2.18</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[3]</th>\n",
" <td>36.684</td>\n",
" <td>9.548</td>\n",
" <td>22.000</td>\n",
" <td>48.000</td>\n",
" <td>4.414</td>\n",
" <td>3.340</td>\n",
" <td>7.0</td>\n",
" <td>4.0</td>\n",
" <td>1.62</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[4]</th>\n",
" <td>40.497</td>\n",
" <td>5.281</td>\n",
" <td>35.000</td>\n",
" <td>49.000</td>\n",
" <td>0.660</td>\n",
" <td>0.469</td>\n",
" <td>67.0</td>\n",
" <td>45.0</td>\n",
" <td>1.03</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[5]</th>\n",
" <td>19.262</td>\n",
" <td>14.065</td>\n",
" <td>1.000</td>\n",
" <td>46.000</td>\n",
" <td>6.697</td>\n",
" <td>5.089</td>\n",
" <td>5.0</td>\n",
" <td>21.0</td>\n",
" <td>2.54</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[6]</th>\n",
" <td>33.946</td>\n",
" <td>15.123</td>\n",
" <td>4.000</td>\n",
" <td>49.000</td>\n",
" <td>7.146</td>\n",
" <td>5.425</td>\n",
" <td>7.0</td>\n",
" <td>6.0</td>\n",
" <td>1.63</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[7]</th>\n",
" <td>28.802</td>\n",
" <td>13.438</td>\n",
" <td>9.000</td>\n",
" <td>49.000</td>\n",
" <td>6.275</td>\n",
" <td>4.755</td>\n",
" <td>5.0</td>\n",
" <td>32.0</td>\n",
" <td>2.15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[8]</th>\n",
" <td>28.334</td>\n",
" <td>14.984</td>\n",
" <td>1.000</td>\n",
" <td>47.000</td>\n",
" <td>7.218</td>\n",
" <td>5.496</td>\n",
" <td>5.0</td>\n",
" <td>13.0</td>\n",
" <td>2.38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_times[9]</th>\n",
" <td>14.959</td>\n",
" <td>13.127</td>\n",
" <td>1.000</td>\n",
" <td>43.000</td>\n",
" <td>6.027</td>\n",
" <td>4.555</td>\n",
" <td>5.0</td>\n",
" <td>10.0</td>\n",
" <td>2.38</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[0]</th>\n",
" <td>0.312</td>\n",
" <td>0.355</td>\n",
" <td>-0.256</td>\n",
" <td>0.848</td>\n",
" <td>0.074</td>\n",
" <td>0.053</td>\n",
" <td>24.0</td>\n",
" <td>269.0</td>\n",
" <td>1.13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[1]</th>\n",
" <td>0.166</td>\n",
" <td>0.564</td>\n",
" <td>-0.841</td>\n",
" <td>0.893</td>\n",
" <td>0.193</td>\n",
" <td>0.141</td>\n",
" <td>10.0</td>\n",
" <td>25.0</td>\n",
" <td>1.34</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[2]</th>\n",
" <td>-0.070</td>\n",
" <td>0.611</td>\n",
" <td>-0.938</td>\n",
" <td>0.847</td>\n",
" <td>0.236</td>\n",
" <td>0.174</td>\n",
" <td>8.0</td>\n",
" <td>264.0</td>\n",
" <td>1.44</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[3]</th>\n",
" <td>-0.104</td>\n",
" <td>0.597</td>\n",
" <td>-1.078</td>\n",
" <td>0.738</td>\n",
" <td>0.242</td>\n",
" <td>0.180</td>\n",
" <td>7.0</td>\n",
" <td>63.0</td>\n",
" <td>1.54</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[4]</th>\n",
" <td>0.047</td>\n",
" <td>0.654</td>\n",
" <td>-1.079</td>\n",
" <td>1.142</td>\n",
" <td>0.186</td>\n",
" <td>0.135</td>\n",
" <td>15.0</td>\n",
" <td>82.0</td>\n",
" <td>1.27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[5]</th>\n",
" <td>0.014</td>\n",
" <td>0.852</td>\n",
" <td>-1.105</td>\n",
" <td>1.840</td>\n",
" <td>0.233</td>\n",
" <td>0.169</td>\n",
" <td>14.0</td>\n",
" <td>167.0</td>\n",
" <td>1.22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[6]</th>\n",
" <td>0.323</td>\n",
" <td>0.820</td>\n",
" <td>-1.325</td>\n",
" <td>1.933</td>\n",
" <td>0.033</td>\n",
" <td>0.059</td>\n",
" <td>558.0</td>\n",
" <td>746.0</td>\n",
" <td>1.23</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[7]</th>\n",
" <td>0.343</td>\n",
" <td>0.944</td>\n",
" <td>-1.410</td>\n",
" <td>2.192</td>\n",
" <td>0.028</td>\n",
" <td>0.027</td>\n",
" <td>1143.0</td>\n",
" <td>955.0</td>\n",
" <td>1.04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[8]</th>\n",
" <td>0.370</td>\n",
" <td>1.024</td>\n",
" <td>-1.724</td>\n",
" <td>2.158</td>\n",
" <td>0.027</td>\n",
" <td>0.024</td>\n",
" <td>1456.0</td>\n",
" <td>1313.0</td>\n",
" <td>1.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_deltas[9]</th>\n",
" <td>0.341</td>\n",
" <td>1.013</td>\n",
" <td>-1.617</td>\n",
" <td>2.180</td>\n",
" <td>0.027</td>\n",
" <td>0.026</td>\n",
" <td>1431.0</td>\n",
" <td>801.0</td>\n",
" <td>1.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>global_mean</th>\n",
" <td>-0.298</td>\n",
" <td>0.100</td>\n",
" <td>-0.499</td>\n",
" <td>-0.125</td>\n",
" <td>0.010</td>\n",
" <td>0.007</td>\n",
" <td>90.0</td>\n",
" <td>274.0</td>\n",
" <td>1.04</td>\n",
" </tr>\n",
" <tr>\n",
" <th>cp_sd</th>\n",
" <td>0.336</td>\n",
" <td>0.253</td>\n",
" <td>0.000</td>\n",
" <td>0.763</td>\n",
" <td>0.008</td>\n",
" <td>0.006</td>\n",
" <td>784.0</td>\n",
" <td>647.0</td>\n",
" <td>1.01</td>\n",
" </tr>\n",
" <tr>\n",
" <th>noise_sd</th>\n",
" <td>0.145</td>\n",
" <td>0.019</td>\n",
" <td>0.111</td>\n",
" <td>0.181</td>\n",
" <td>0.002</td>\n",
" <td>0.002</td>\n",
" <td>48.0</td>\n",
" <td>57.0</td>\n",
" <td>1.06</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n",
"cp_times[0] 24.256 17.229 1.000 47.000 8.301 6.320 5.0 \n",
"cp_times[1] 40.992 4.821 36.000 49.000 0.547 0.388 73.0 \n",
"cp_times[2] 22.595 12.367 9.000 48.000 5.797 4.407 6.0 \n",
"cp_times[3] 36.684 9.548 22.000 48.000 4.414 3.340 7.0 \n",
"cp_times[4] 40.497 5.281 35.000 49.000 0.660 0.469 67.0 \n",
"cp_times[5] 19.262 14.065 1.000 46.000 6.697 5.089 5.0 \n",
"cp_times[6] 33.946 15.123 4.000 49.000 7.146 5.425 7.0 \n",
"cp_times[7] 28.802 13.438 9.000 49.000 6.275 4.755 5.0 \n",
"cp_times[8] 28.334 14.984 1.000 47.000 7.218 5.496 5.0 \n",
"cp_times[9] 14.959 13.127 1.000 43.000 6.027 4.555 5.0 \n",
"cp_deltas[0] 0.312 0.355 -0.256 0.848 0.074 0.053 24.0 \n",
"cp_deltas[1] 0.166 0.564 -0.841 0.893 0.193 0.141 10.0 \n",
"cp_deltas[2] -0.070 0.611 -0.938 0.847 0.236 0.174 8.0 \n",
"cp_deltas[3] -0.104 0.597 -1.078 0.738 0.242 0.180 7.0 \n",
"cp_deltas[4] 0.047 0.654 -1.079 1.142 0.186 0.135 15.0 \n",
"cp_deltas[5] 0.014 0.852 -1.105 1.840 0.233 0.169 14.0 \n",
"cp_deltas[6] 0.323 0.820 -1.325 1.933 0.033 0.059 558.0 \n",
"cp_deltas[7] 0.343 0.944 -1.410 2.192 0.028 0.027 1143.0 \n",
"cp_deltas[8] 0.370 1.024 -1.724 2.158 0.027 0.024 1456.0 \n",
"cp_deltas[9] 0.341 1.013 -1.617 2.180 0.027 0.026 1431.0 \n",
"global_mean -0.298 0.100 -0.499 -0.125 0.010 0.007 90.0 \n",
"cp_sd 0.336 0.253 0.000 0.763 0.008 0.006 784.0 \n",
"noise_sd 0.145 0.019 0.111 0.181 0.002 0.002 48.0 \n",
"\n",
" ess_tail r_hat \n",
"cp_times[0] 11.0 2.04 \n",
"cp_times[1] 68.0 1.04 \n",
"cp_times[2] 6.0 2.18 \n",
"cp_times[3] 4.0 1.62 \n",
"cp_times[4] 45.0 1.03 \n",
"cp_times[5] 21.0 2.54 \n",
"cp_times[6] 6.0 1.63 \n",
"cp_times[7] 32.0 2.15 \n",
"cp_times[8] 13.0 2.38 \n",
"cp_times[9] 10.0 2.38 \n",
"cp_deltas[0] 269.0 1.13 \n",
"cp_deltas[1] 25.0 1.34 \n",
"cp_deltas[2] 264.0 1.44 \n",
"cp_deltas[3] 63.0 1.54 \n",
"cp_deltas[4] 82.0 1.27 \n",
"cp_deltas[5] 167.0 1.22 \n",
"cp_deltas[6] 746.0 1.23 \n",
"cp_deltas[7] 955.0 1.04 \n",
"cp_deltas[8] 1313.0 1.00 \n",
"cp_deltas[9] 801.0 1.00 \n",
"global_mean 274.0 1.04 \n",
"cp_sd 647.0 1.01 \n",
"noise_sd 57.0 1.06 "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"az.summary(trace)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"R-hat is awful, but we don't care about those changepoint specific variables that correspond to implausible number of changepoints. The global parameters make sense at least:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1200x400 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"az.plot_posterior(\n",
" trace, var_names=[\"noise_sd\", \"cp_sd\", \"global_mean\"], \n",
" ref_val=[noise_sd_true, cp_deltas_true.std(), noiseless[0]],\n",
" figsize=(12, 4)\n",
");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Recovering the marginalized variable and checking inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's recover the marginalized variable `n_cps`, so that we can see what he model actually predicts for the mean and changepoints"
]
},
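{
"cell_type": "markdown",
"metadata": {},
"source": [
"The recovery rests on the conditional posterior of `n_cps` given one posterior draw of everything else. A sketch under the assumptions of the model above, where $\\theta^{(s)}$ denotes posterior draw $s$, $K$ stands for `max_cp_inference`, and $\\ell_k$ is a symbol introduced here for the unnormalized log-probability:\n",
"\n",
"$$\\ell_k = \\log p(n_{cps}=k) + \\log p(x \\mid n_{cps}=k, \\theta^{(s)}), \\qquad p(n_{cps}=k \\mid x, \\theta^{(s)}) = \\frac{\\exp \\ell_k}{\\sum_{j=0}^{K} \\exp \\ell_j}$$\n",
"\n",
"Sampling $k$ from these probabilities would give proper posterior draws of `n_cps`. Taking $\\arg\\max_k \\ell_k$ instead, which is the shortcut used in the cells that follow, keeps only the conditional mode for each draw."
]
},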
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"with pm.Model() as recover_m:\n",
" n_cps = pm.DiscreteUniform(\"n_cps\", 0, max_cp_inference)\n",
" cp_times = pm.DiscreteUniform(\"cp_times\", 1, T-1, shape=max_cp_inference)\n",
"\n",
" # We need to disable transforms so that we can use the posterior points directly\n",
" # As the transformed draws are not saved\n",
" cp_sd = pm.HalfNormal('cp_sd', sigma=2, transform=None)\n",
" cp_deltas = pm.Normal('cp_deltas', cp_sd, shape=max_cp_inference)\n",
" \n",
" global_mean = pm.Normal('global_mean', sigma=1)\n",
" noise_sd = pm.HalfNormal('noise_sd', sigma=1, transform=None)\n",
" \n",
" cp_times_sorted = cp_times.sort()\n",
" is_timestep_past_cp = (tiled_times >= cp_times_sorted[None, :].repeat(T, axis=0))\n",
" cp_contrib = at.sum(cp_deltas[:n_cps] * is_timestep_past_cp[:, :n_cps], axis=1)\n",
" \n",
" mu = pm.Deterministic(\"mu\", global_mean + cp_contrib)\n",
" \n",
" llike = pm.Normal('likelihood', mu=mu, sigma=noise_sd, observed=xs)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Only the likelihood depends directly on n_cps, so we create a logp function that only considers those two\n",
"logp_fn = recover_m.compile_fn(\n",
" recover_m.logp([n_cps, llike]), \n",
" inputs=recover_m.value_vars,\n",
" on_unused_input=\"ignore\",\n",
" point_fn=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2000"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pymc.util import dataset_to_point_list\n",
"points, _ = dataset_to_point_list(az.extract(trace), sample_dims=(\"sample\",))\n",
"len(points)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-259.7478298 , -138.74512393, -123.36547779, -291.10229905,\n",
" -394.30545815, -31.77460832, -60.62234883, -91.30291365,\n",
" 20.99261401, 20.8252028 , 19.37766871])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# logp for first point\n",
"logps = np.array([logp_fn(**points[0], n_cps=n_cps) for n_cps in range(0, max_cp_inference + 1)])\n",
"logps"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# For simplicity we will take the argmax, instead of sampling from the logps\n",
"np.argmax(logps)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The most likely `n_cps` for the first draw is 4. Let's now recover the most likely for each point in the posterior"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.0085"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"possible_n_cps = list(range(max_cp_inference+1))\n",
"n_cps = np.array([\n",
" np.argmax([\n",
" [logp_fn(**point, n_cps=n_cps) for n_cps in possible_n_cps]\n",
" \n",
" ])\n",
" for point in points\n",
"])\n",
"n_cps.mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The posterior mean overshoots the real number of changepoints. I would blame it on the prior?"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"az.plot_posterior(n_cps, ref_val=n_cps_true);"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"post = az.extract(trace)\n",
"post[\"n_cps\"] = xr.DataArray(n_cps, dims=\"sample\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the `n_cps` recovered we can now also recover any other deterministics of interest, by exploiting `posterior_predictive`"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Sampling: []\n"
]
},
{
"data": {
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <progress value='2000' class='' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [2000/2000 00:00&lt;00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"with recover_m:\n",
" mu = pm.sample_posterior_predictive(\n",
" post, sample_dims=[\"sample\"], \n",
" var_names=[\"mu\"],\n",
" random_seed=rng,\n",
" ).posterior_predictive[\"mu\"]\n",
"mu = xr.DataArray(mu.values, dims=(\"sample\", \"mu_dim_0\"))\n",
"post[\"mu\"] = mu "
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"def make_inference_plot(trace, xs, noiseless, true_cp, *, figsize=(9,4), n_cp_shown=10):\n",
" T = len(xs)\n",
" top_cp = Counter(\n",
" trace[\"cp_times\"].to_numpy().astype(int).ravel().tolist()\n",
" ).most_common(n_cp_shown)\n",
"\n",
" plt.figure(figsize=figsize)\n",
" plt.plot(noiseless, label='True noiseless values', color='green')\n",
"\n",
" plt.plot(trace['mu'].mean(axis=(0,)), label='Inferred noiseless mean', color='orange')\n",
"\n",
" q10, q90 = np.percentile(trace['mu'], [10,90], axis=(0,))\n",
" plt.fill_between(np.arange(T), q10, q90, color='orange', alpha=0.2)\n",
" plt.plot(xs, linestyle='', color='k', marker='o', label='Observed data')\n",
" for i, cp in enumerate(true_cp):\n",
" if i == 0:\n",
" label = 'True change point'\n",
" else:\n",
" label=None\n",
" plt.axvline(cp, color='g', linestyle='--',label=label)\n",
"\n",
" for i, (t, _) in enumerate(top_cp):\n",
" if i == 0:\n",
" label = 'Inferred change point'\n",
" else:\n",
" label=None\n",
" # The recovered change points seem to be one unit to the left\n",
" # Probably I messed up something in the model specification\n",
" plt.axvline(t+0.1, color='orange', linestyle='--', label=label)\n",
" plt.xlabel('Timestep',fontsize=12)\n",
" plt.ylabel('$X(t)$',fontsize=18)\n",
" plt.legend(loc=(1, .6));"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 900x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"make_inference_plot(post, xs, noiseless, cp_times_true, n_cp_shown=6)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"hide_input": false,
"kernelspec": {
"display_name": "pymcx",
"language": "python",
"name": "pymcx"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}