Created March 7, 2020 17:17 (GitHub gist fehiepsi/8bc0cd43289a60e79c176d73912a7fb2)
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Forecasting I: univariate, heavy tailed\n", | |
"\n", | |
"This tutorial introduces the [pyro.contrib.forecast](http://docs.pyro.ai/en/latest/contrib.forecast.html) module, a framework for forecasting with Pyro models. This tutorial covers only univariate models and simple likelihoods. This tutorial assumes the reader is already familiar with [SVI](http://pyro.ai/examples/svi_part_ii.html) and [tensor shapes](http://pyro.ai/examples/tensor_shapes.html).\n", | |
"\n", | |
"#### Summary\n", | |
"\n", | |
"- To create a forecasting model:\n", | |
" 1. Create a subclass of the [ForecastingModel](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.ForecastingModel) class.\n", | |
" 2. Implement the [.model(zero_data, covariates)](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.ForecastingModel.model) method using standard Pyro syntax.\n", | |
" 3. Sample all time-local variables inside the [self.time_plate](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.ForecastingModel.time_plate) context.\n", | |
" 4. Finally, call the [.predict(noise_dist, prediction)](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.ForecastingModel.predict) method.\n",
"- To train a forecasting model, create a [Forecaster](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.Forecaster) object; this notebook instead uses `HMCForecaster`, which draws posterior samples with Hamiltonian Monte Carlo.\n",
" - Training can be flaky; you may need to tune hyperparameters and restart with different random seeds.\n",
" - Reparameterization can help learning, e.g. [LocScaleReparam](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.loc_scale.LocScaleReparam).\n", | |
"- To forecast the future, draw samples from a `Forecaster` object conditioned on data and covariates.\n", | |
"- To model seasonality, use helpers [periodic_features()](http://docs.pyro.ai/en/latest/ops.html#pyro.ops.tensor_utils.periodic_features), [periodic_repeat()](http://docs.pyro.ai/en/latest/ops.html#pyro.ops.tensor_utils.periodic_repeat), and [periodic_cumsum()](http://docs.pyro.ai/en/latest/ops.html#pyro.ops.tensor_utils.periodic_cumsum).\n", | |
"- To model heavy-tailed data, use [Stable](http://docs.pyro.ai/en/latest/distributions.html#stable) distributions and [StableReparam](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.stable.StableReparam).\n", | |
"- To evaluate results, use the [backtest()](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.evaluate.backtest) helper or low-level loss functions."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import pyro\n", | |
"import pyro.distributions as dist\n", | |
"import pyro.poutine as poutine\n", | |
"from pyro.contrib.examples.bart import load_bart_od\n", | |
"from pyro.contrib.forecast import ForecastingModel, HMCForecaster, backtest, eval_crps\n", | |
"from pyro.infer.reparam import LocScaleReparam, StableReparam\n", | |
"from pyro.ops.tensor_utils import periodic_cumsum, periodic_repeat, periodic_features\n", | |
"from pyro.ops.stats import quantile\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"%matplotlib inline\n", | |
"assert pyro.__version__.startswith('1.2.1')\n", | |
"pyro.enable_validation(True)\n", | |
"pyro.set_rng_seed(20200221)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"dict_keys(['stations', 'start_date', 'counts'])\n", | |
"torch.Size([78888, 50, 50])\n", | |
"12TH 16TH 19TH 24TH ANTC ASHB BALB BAYF BERY CAST CIVC COLM COLS CONC DALY DBRK DELN DUBL EMBR FRMT FTVL GLEN HAYW LAFY LAKE MCAR MLBR MLPT MONT NBRK NCON OAKL ORIN PCTR PHIL PITT PLZA POWL RICH ROCK SANL SBRN SFIA SHAY SSAN UCTY WARM WCRK WDUB WOAK\n" | |
] | |
} | |
], | |
"source": [ | |
"dataset = load_bart_od()\n", | |
"print(dataset.keys())\n", | |
"print(dataset[\"counts\"].shape)\n", | |
"print(\" \".join(dataset[\"stations\"]))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Intro to Pyro's forecasting framework\n", | |
"\n", | |
"Pyro's forecasting framework consists of:\n", | |
"- a [ForecastingModel](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.ForecastingModel) base class, whose ``.model()`` method can be implemented for custom forecasting models,\n", | |
"- a [Forecaster](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.Forecaster) class that trains and forecasts using ``ForecastingModel``s, and\n", | |
"- a [backtest()](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.evaluate.backtest) helper to evaluate models on a number of metrics.\n", | |
"\n", | |
"Consider a simple univariate dataset, say weekly [BART train](https://www.bart.gov/about/reports/ridership) ridership aggregated over all stations in the network. This data varies on a roughly logarithmic scale, so we log-transform it for modeling."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 648x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"T, O, D = dataset[\"counts\"].shape\n", | |
"data = dataset[\"counts\"][:T // (24 * 7) * 24 * 7].reshape(T // (24 * 7), -1).sum(-1).log()\n", | |
"data = data.unsqueeze(-1)\n", | |
"plt.figure(figsize=(9, 3))\n", | |
"plt.plot(data)\n", | |
"plt.title(\"Total weekly ridership\")\n", | |
"plt.ylabel(\"log(# rides)\")\n", | |
"plt.xlabel(\"Week after 2011-01-01\")\n", | |
"plt.xlim(0, len(data));" | |
] | |
}, | |
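The aggregation in the cell above truncates the hourly series to whole weeks, reshapes it to one row per week, and sums each row before taking the log. The plain-Python sketch below (stdlib only; `weekly_log_totals` is a hypothetical helper, not part of Pyro) mirrors that expression on a list of hourly totals that are assumed to be already summed over origin and destination stations, so the truncation is easy to see. With `T = 78888` hours, `T // (24 * 7)` gives 469 complete weeks and the final 96 hours are dropped.

```python
import math

HOURS_PER_WEEK = 24 * 7  # 168


def weekly_log_totals(hourly_totals):
    """Return log(total rides) for each complete week.

    Plain-Python analogue of
        counts[:T // 168 * 168].reshape(T // 168, -1).sum(-1).log()
    where `hourly_totals` has already been summed over stations.
    Any trailing partial week is discarded.
    """
    num_weeks = len(hourly_totals) // HOURS_PER_WEEK
    weekly = []
    for w in range(num_weeks):
        week = hourly_totals[w * HOURS_PER_WEEK:(w + 1) * HOURS_PER_WEEK]
        weekly.append(math.log(sum(week)))
    return weekly


# Hypothetical toy series: 2.5 weeks of 100 rides per hour.
toy = [100] * (5 * HOURS_PER_WEEK // 2)
weeks = weekly_log_totals(toy)
print(len(weeks))  # -> 2 (the half week at the end is dropped)

# Weeks kept and hours dropped for the BART series (T = 78888 hours).
print(78888 // HOURS_PER_WEEK, 78888 % HOURS_PER_WEEK)  # -> 469 96
```

Truncating before reshaping is what lets the `reshape(T // (24 * 7), -1)` call succeed: the number of remaining elements must divide evenly into weekly rows.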
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's start with a simple log-linear regression model, with no trend or seasonality. Note that while this example is univariate, Pyro's forecasting framework is multivariate, so we'll often need to reshape using `.unsqueeze(-1)`, `.expand([1])`, and `.to_event(1)`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# First we need some boilerplate to create a class and define a .model() method.\n",
"class Model1(ForecastingModel):\n",
"    # We then implement the .model() method. Since this is a generative model, it shouldn't\n",
"    # look at data; however it is convenient to see the shape of data we're supposed to\n",
"    # generate, so this inputs a zeros_like(data) tensor instead of the actual data.\n",
"    def model(self, zero_data, covariates):\n",
"        data_dim = zero_data.size(-1)  # Should be 1 in this univariate tutorial.\n",
"        feature_dim = covariates.size(-1)\n",
"\n",
"        # The first part of the model is a probabilistic program to create a prediction.\n",
"        # We use the zero_data as a template for the shape of the prediction.\n",
"        bias = pyro.sample(\"bias\", dist.Normal(0, 10).expand([data_dim]).to_event(1))\n",
"        weight = pyro.sample(\"weight\", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))\n",
"        prediction = bias + (weight * covariates).sum(-1, keepdim=True)\n",
"        # The prediction should have the same shape as zero_data (duration, obs_dim),\n",
"        # but may have additional sample dimensions on the left.\n",
"        assert prediction.shape[-2:] == zero_data.shape\n",
"\n",
"        # The next part of the model creates a likelihood or noise distribution.\n",
"        # Again we'll be Bayesian and write this as a probabilistic program with\n",
"        # priors over parameters, and again we'll use zero_data as a noise template.\n",
"        noise_scale = pyro.sample(\"noise_scale\", dist.LogNormal(-5, 5).expand([1]).to_event(1))\n",
"        noise_dist = dist.Normal(zero_data, noise_scale)\n",
"\n",
"        # The final step is to call the .predict() method.\n",
"        self.predict(noise_dist, prediction)"
] | |
}, | |
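`Model1` predicts an affine function of the covariates and adds independent Gaussian noise at each time step. As an illustration of that generative process only, the stdlib sketch below (hypothetical helper `simulate_model1`; this is not how Pyro executes the model) draws one synthetic series for fixed parameter values, mirroring `bias + (weight * covariates).sum(-1, keepdim=True)` followed by `dist.Normal(zero_data, noise_scale)` noise. The parameter values are chosen to be close to the posterior means printed for `Model1` further down, and the single covariate is the notebook's `time = week_index / 365`.

```python
import random


def simulate_model1(covariates, bias, weight, noise_scale, rng):
    """Draw y_t = bias + sum_f weight[f] * x_t[f] + noise_t for each time step.

    covariates: list of per-time-step feature lists, shape (duration, feature_dim).
    weight: list of length feature_dim.
    noise_t ~ Normal(0, noise_scale), independent across time steps.
    """
    series = []
    for x_t in covariates:
        prediction = bias + sum(w * x for w, x in zip(weight, x_t))
        series.append(prediction + rng.gauss(0.0, noise_scale))
    return series


# 469 weeks of data, one covariate: time = week_index / 365.
num_weeks = 469
covariates = [[t / 365] for t in range(num_weeks)]
rng = random.Random(0)
synthetic = simulate_model1(covariates, bias=14.58, weight=[0.12],
                            noise_scale=0.13, rng=rng)
print(len(synthetic), synthetic[:3])
```

Setting `noise_scale=0` recovers the deterministic `prediction` term, which is a quick way to check the affine part in isolation.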
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can now fit this model. The [Forecaster](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.forecaster.Forecaster) class trains with SVI; here we use its MCMC counterpart `HMCForecaster`, which draws posterior samples with Hamiltonian Monte Carlo. We'll split the data into `[T0,T1)` for training and `[T1,T2)` for testing."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"T0 = 0 # beginning\n",
"T2 = data.size(-2) # end\n", | |
"T1 = T2 - 52 # train/test split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sample: 100%|██████████| 2000/2000 [00:39, 50.90it/s, step size=3.95e-01, acc. prob=0.930]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
" mean std median 5.0% 95.0% n_eff r_hat\n", | |
" bias[0] 14.58 0.01 14.58 14.56 14.59 501.54 1.00\n", | |
" weight[0] 0.12 0.02 0.12 0.08 0.14 513.97 1.00\n", | |
"noise_scale[0] 0.13 0.00 0.13 0.12 0.13 592.00 1.00\n", | |
"\n", | |
"Number of divergences: 0\n", | |
"CPU times: user 39.4 s, sys: 194 ms, total: 39.6 s\n", | |
"Wall time: 39.3 s\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"pyro.set_rng_seed(1)\n", | |
"pyro.clear_param_store()\n", | |
"time = torch.arange(float(T2)) / 365\n", | |
"covariates = torch.stack([time], dim=-1)\n", | |
"forecaster = HMCForecaster(Model1(), data[:T1], covariates[:T1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Next we can evaluate by drawing posterior samples from the forecaster, passing in the full covariates but only the training portion of the data. We'll use Pyro's [quantile()](http://docs.pyro.ai/en/latest/ops.html#pyro.ops.stats.quantile) function to plot the median and an 80% interval (10th to 90th percentiles) of the forecast. To evaluate fit we'll use [eval_crps()](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.evaluate.eval_crps) to compute the [Continuous Ranked Probability Score](https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf) (CRPS); this is a good metric for assessing the distributional fit of a heavy-tailed distribution."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([1000, 52, 1]) torch.Size([52])\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 648x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"samples = forecaster(data[:T1], covariates, num_samples=1000)\n", | |
"p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)\n", | |
"crps = eval_crps(samples, data[T1:])\n", | |
"print(samples.shape, p10.shape)\n", | |
"\n", | |
"plt.figure(figsize=(9, 3))\n", | |
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n", | |
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n", | |
"plt.plot(data, 'k-', label='truth')\n", | |
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n", | |
"plt.ylabel(\"log(# rides)\")\n", | |
"plt.xlabel(\"Week after 2011-01-01\")\n", | |
"plt.xlim(0, None)\n", | |
"plt.legend(loc=\"best\");" | |
] | |
}, | |
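The `quantile()` and `eval_crps()` calls above operate on a tensor of forecast samples. As a library-free illustration of what they measure, the sketch below computes an empirical 80% interval and a sample-based CRPS estimate using the identity CRPS(F, y) = E|X - y| - 0.5 E|X - X'|, where X and X' are independent draws from the forecast F (this form appears in the Gneiting and Raftery paper linked above). The helper names are hypothetical; Pyro's `eval_crps` may use a different estimator, and `statistics.quantiles` has its own interpolation rule, so numbers can differ slightly from Pyro's. Treat this as a sketch of the definitions, not of Pyro's implementation.

```python
import statistics


def empirical_crps(samples, truth):
    """Sample estimate of CRPS(F, truth) = E|X - y| - 0.5 * E|X - X'|.

    `samples` are draws from the forecast distribution F for one time step.
    This O(n^2) form is fine for a few thousand samples.
    """
    n = len(samples)
    abs_error = sum(abs(x - truth) for x in samples) / n
    spread = sum(abs(x - x2) for x in samples for x2 in samples) / (2 * n * n)
    return abs_error - spread


def crps_over_horizon(samples_per_step, truths):
    """Average the per-step CRPS over a forecast horizon."""
    scores = [empirical_crps(s, y) for s, y in zip(samples_per_step, truths)]
    return sum(scores) / len(scores)


samples = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
# With n=10, statistics.quantiles returns the 9 deciles: the first and last
# bound an 80% interval and the middle one is the median.
deciles = statistics.quantiles(samples, n=10)
p10, p50, p90 = deciles[0], deciles[4], deciles[8]
print(p10, p50, p90)

# A point forecast (all samples equal) reduces CRPS to absolute error |3 - 1|.
print(empirical_crps([3.0, 3.0], truth=1.0))  # -> 2.0
```

The spread term is what distinguishes CRPS from plain absolute error: a forecast is rewarded for being close to the truth and penalized for being needlessly diffuse only insofar as that diffuseness does not buy accuracy.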
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Zooming in to just the forecasted region, we see this model ignores seasonal behavior." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 648x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.figure(figsize=(9, 3))\n", | |
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n", | |
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n", | |
"plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')\n", | |
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n", | |
"plt.ylabel(\"log(# rides)\")\n", | |
"plt.xlabel(\"Week after 2011-01-01\")\n", | |
"plt.xlim(T1, None)\n", | |
"plt.legend(loc=\"best\");" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We could add a yearly seasonal component simply by adding new covariates (note we've already written the model to handle `feature_dim > 1`)."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sample: 100%|██████████| 2000/2000 [00:58, 34.04it/s, step size=3.49e-01, acc. prob=0.895]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
" mean std median 5.0% 95.0% n_eff r_hat\n", | |
" bias[0] 14.57 0.01 14.57 14.56 14.59 1531.06 1.00\n", | |
" weight[0] 0.12 0.01 0.12 0.10 0.15 1460.82 1.00\n", | |
" weight[1] -0.04 0.01 -0.04 -0.05 -0.03 1862.32 1.00\n", | |
" weight[2] -0.05 0.01 -0.05 -0.06 -0.04 1810.56 1.00\n", | |
" weight[3] -0.01 0.01 -0.01 -0.02 -0.00 1462.10 1.00\n", | |
" weight[4] -0.02 0.01 -0.02 -0.03 -0.01 2460.94 1.00\n", | |
" weight[5] -0.02 0.01 -0.02 -0.03 -0.01 1898.95 1.00\n", | |
" weight[6] -0.03 0.01 -0.03 -0.04 -0.02 1444.50 1.00\n", | |
" weight[7] -0.01 0.01 -0.01 -0.02 -0.00 2423.26 1.00\n", | |
" weight[8] -0.04 0.01 -0.04 -0.05 -0.03 1654.61 1.00\n", | |
" weight[9] -0.02 0.01 -0.02 -0.03 -0.01 1878.95 1.00\n", | |
" weight[10] -0.03 0.01 -0.03 -0.04 -0.02 2050.28 1.00\n", | |
" weight[11] 0.03 0.01 0.03 0.02 0.04 1454.76 1.00\n", | |
" weight[12] 0.00 0.01 0.00 -0.01 0.01 1848.83 1.00\n", | |
" weight[13] 0.03 0.01 0.03 0.02 0.04 2235.39 1.00\n", | |
" weight[14] 0.01 0.01 0.01 -0.00 0.02 1873.06 1.00\n", | |
" weight[15] 0.01 0.01 0.01 -0.00 0.02 1995.56 1.00\n", | |
" weight[16] 0.00 0.01 0.00 -0.01 0.01 1757.97 1.00\n", | |
" weight[17] 0.01 0.01 0.01 0.01 0.02 1717.18 1.00\n", | |
" weight[18] -0.01 0.01 -0.01 -0.02 0.00 1986.53 1.00\n", | |
" weight[19] 0.02 0.01 0.02 0.01 0.03 1626.81 1.00\n", | |
" weight[20] 0.01 0.01 0.01 -0.00 0.02 2021.86 1.00\n", | |
" weight[21] 0.03 0.01 0.03 0.01 0.04 1716.86 1.00\n", | |
" weight[22] 0.01 0.01 0.01 0.00 0.02 1506.92 1.00\n", | |
" weight[23] 0.02 0.01 0.02 0.00 0.03 1866.89 1.00\n", | |
" weight[24] -0.01 0.01 -0.01 -0.02 0.00 1733.13 1.00\n", | |
" weight[25] -0.00 0.01 -0.00 -0.01 0.01 1551.61 1.00\n", | |
" weight[26] -0.01 0.01 -0.01 -0.02 0.00 1792.28 1.00\n", | |
" weight[27] 0.00 0.01 0.00 -0.01 0.01 1563.69 1.00\n", | |
" weight[28] -0.02 0.01 -0.02 -0.03 -0.01 1930.02 1.00\n", | |
" weight[29] -0.02 0.01 -0.02 -0.03 -0.01 1721.07 1.00\n", | |
" weight[30] -0.01 0.01 -0.01 -0.03 -0.00 1956.95 1.00\n", | |
" weight[31] 0.00 0.01 0.00 -0.01 0.01 2298.21 1.00\n", | |
" weight[32] -0.00 0.01 -0.00 -0.01 0.01 1698.65 1.00\n", | |
" weight[33] -0.02 0.01 -0.02 -0.03 -0.01 1480.71 1.00\n", | |
" weight[34] -0.00 0.01 -0.00 -0.01 0.01 1725.80 1.00\n", | |
" weight[35] -0.02 0.01 -0.02 -0.03 -0.01 1622.77 1.00\n", | |
" weight[36] -0.03 0.01 -0.03 -0.04 -0.02 1649.00 1.00\n", | |
" weight[37] -0.02 0.01 -0.02 -0.03 -0.01 1941.26 1.00\n", | |
" weight[38] -0.03 0.01 -0.03 -0.04 -0.02 1465.99 1.00\n", | |
" weight[39] -0.01 0.01 -0.01 -0.02 -0.00 1757.78 1.00\n", | |
" weight[40] -0.01 0.01 -0.01 -0.02 -0.00 2271.25 1.00\n", | |
" weight[41] -0.01 0.01 -0.01 -0.02 0.00 1805.23 1.00\n", | |
" weight[42] -0.02 0.01 -0.02 -0.03 -0.00 1604.31 1.00\n", | |
" weight[43] -0.00 0.01 -0.00 -0.01 0.01 1767.37 1.00\n", | |
" weight[44] -0.01 0.01 -0.01 -0.02 0.00 1680.94 1.00\n", | |
" weight[45] -0.01 0.01 -0.01 -0.02 -0.00 1844.59 1.00\n", | |
" weight[46] -0.02 0.01 -0.02 -0.03 -0.01 1804.98 1.00\n", | |
" weight[47] 0.00 0.01 0.00 -0.01 0.01 1516.20 1.00\n", | |
" weight[48] -0.02 0.01 -0.02 -0.02 -0.00 1346.75 1.00\n", | |
" weight[49] 0.02 0.01 0.02 0.01 0.03 1521.11 1.00\n", | |
" weight[50] 0.01 0.01 0.01 -0.00 0.02 1747.17 1.00\n", | |
" weight[51] 0.01 0.01 0.01 0.00 0.02 1648.51 1.00\n", | |
" weight[52] 0.00 0.01 0.00 -0.01 0.01 1628.19 1.00\n", | |
"noise_scale[0] 0.09 0.00 0.09 0.08 0.10 1054.48 1.00\n", | |
"\n", | |
"Number of divergences: 0\n", | |
"CPU times: user 1min, sys: 345 ms, total: 1min 1s\n", | |
"Wall time: 58.8 s\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"pyro.set_rng_seed(1)\n", | |
"pyro.clear_param_store()\n", | |
"time = torch.arange(float(T2)) / 365\n", | |
"covariates = torch.cat([time.unsqueeze(-1),\n",
"                        periodic_features(T2, 365.25 / 7)], dim=-1)\n",
"forecaster = HMCForecaster(Model1(), data[:T1], covariates[:T1])" | |
] | |
}, | |
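`periodic_features(T2, 365.25 / 7)` supplies Fourier-style seasonal covariates for a yearly period measured in weeks; the summary above reports 53 weights (`weight[0]` through `weight[52]`), i.e. the time covariate plus 52 periodic features. The plain-Python sketch below (hypothetical `fourier_features`, not Pyro's implementation; Pyro's feature ordering and sign conventions may differ) builds a cosine and a sine feature for every harmonic k of `max_period` whose period `max_period / k` is strictly longer than `min_period = 2` steps, which reproduces that count of 52.

```python
import math


def fourier_features(duration, max_period, min_period=2.0):
    """Return a duration x (2 * num_harmonics) list of seasonal features.

    Harmonic k has period max_period / k. We keep every k >= 1 whose period
    is strictly longer than min_period (a period of 2 steps is the Nyquist
    limit for a series sampled once per step).
    """
    num_harmonics = math.ceil(max_period / min_period) - 1
    features = []
    for t in range(duration):
        row = []
        for k in range(1, num_harmonics + 1):
            angle = 2 * math.pi * k * t / max_period
            row.extend([math.cos(angle), math.sin(angle)])
        features.append(row)
    return features


weeks_per_year = 365.25 / 7
features = fourier_features(469, weeks_per_year)
print(len(features), len(features[0]))  # -> 469 52
```

Because the model is linear in these features, a weighted sum of them can represent any smooth yearly pattern up to the highest retained harmonic, which is why adding covariates is enough to capture seasonality here.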
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 648x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"samples = forecaster(data[:T1], covariates, num_samples=1000)\n", | |
"p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)\n", | |
"crps = eval_crps(samples, data[T1:])\n", | |
"\n", | |
"plt.figure(figsize=(9, 3))\n", | |
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n", | |
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n", | |
"plt.plot(data, 'k-', label='truth')\n", | |
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n", | |
"plt.ylabel(\"log(# rides)\")\n", | |
"plt.xlabel(\"Week after 2011-01-01\")\n", | |
"plt.xlim(0, None)\n", | |
"plt.legend(loc=\"best\");" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 648x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.figure(figsize=(9, 3))\n", | |
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n", | |
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n", | |
"plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')\n", | |
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n", | |
"plt.ylabel(\"log(# rides)\")\n", | |
"plt.xlabel(\"Week after 2011-01-01\")\n", | |
"plt.xlim(T1, None)\n", | |
"plt.legend(loc=\"best\");" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Time-local random variables: `self.time_plate`\n", | |
"\n", | |
"So far we've seen the ``ForecastingModel.model()`` method and ``self.predict()``. The last piece of forecasting-specific syntax is the ``self.time_plate`` context for time-local variables. To see how this works, consider changing our global linear trend model above to a local level model. Note the [poutine.reparam()](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.handlers.reparam) handler is a general Pyro inference trick, not specific to forecasting." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model2(ForecastingModel):\n",
"    def model(self, zero_data, covariates):\n",
"        data_dim = zero_data.size(-1)\n",
"        feature_dim = covariates.size(-1)\n",
"        bias = pyro.sample(\"bias\", dist.Normal(0, 10).expand([data_dim]).to_event(1))\n",
"        weight = pyro.sample(\"weight\", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))\n",
"\n",
"        # We'll sample a time-global scale parameter outside the time plate,\n",
"        # then time-local iid noise inside the time plate.\n",
"        drift_scale = pyro.sample(\"drift_scale\",\n",
"                                  dist.LogNormal(-20, 5).expand([1]).to_event(1))\n",
"        with self.time_plate:\n",
"            # We'll use a reparameterizer to improve inference. The model would still be\n",
"            # correct if you removed this context manager, but inference may be less efficient.\n",
"            with poutine.reparam(config={\"drift\": LocScaleReparam()}):\n",
"                drift = pyro.sample(\"drift\", dist.Normal(zero_data, drift_scale).to_event(1))\n",
"\n",
"        # After we sample the iid \"drift\" noise we can combine it in any time-dependent way.\n",
"        # It is important to keep everything inside the plate independent and apply dependent\n",
"        # transforms outside the plate.\n",
"        motion = drift.cumsum(-2)  # A Brownian motion.\n",
"\n",
"        # The prediction now includes three terms.\n",
"        prediction = motion + bias + (weight * covariates).sum(-1, keepdim=True)\n",
"        assert prediction.shape[-2:] == zero_data.shape\n",
"\n",
"        # Construct the noise distribution and predict.\n",
"        noise_scale = pyro.sample(\"noise_scale\", dist.LogNormal(-5, 5).expand([1]).to_event(1))\n",
"        noise_dist = dist.Normal(zero_data, noise_scale)\n",
"        self.predict(noise_dist, prediction)"
] | |
}, | |
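`Model2` samples iid `drift` increments inside `self.time_plate` and only afterwards turns them into a dependent Brownian motion with `cumsum(-2)`. The stdlib sketch below (hypothetical helper `local_level_path`) shows the same split in its fully non-centered form: standard-normal draws `z_t` play the role of the time-local variables, scaling by `drift_scale` gives the drift, and a cumulative sum gives the motion, whose variance grows linearly in time. `LocScaleReparam` applies a related decentering to `drift` (it can also interpolate between centered and non-centered forms), which is why the MCMC summary below reports `drift_decentered` sites; this sketch illustrates the idea and is not Pyro's reparameterizer.

```python
import random
import statistics
from itertools import accumulate


def local_level_path(z, drift_scale):
    """Turn iid standard-normal draws z_t into a random-walk level.

    drift_t  = drift_scale * z_t          (time-local, independent)
    motion_t = drift_1 + ... + drift_t    (dependent, built outside the plate)
    """
    drift = [drift_scale * z_t for z_t in z]
    return list(accumulate(drift))


# A deterministic check of the arithmetic.
print(local_level_path([1.0, -1.0, 2.0], drift_scale=0.5))  # -> [0.5, 0.0, 1.0]

# For a Gaussian random walk, Var(motion_T) = T * drift_scale**2.
rng = random.Random(0)
T, scale, num_paths = 50, 0.1, 4000
final_values = [
    local_level_path([rng.gauss(0.0, 1.0) for _ in range(T)], scale)[-1]
    for _ in range(num_paths)
]
print(statistics.pvariance(final_values))  # should be close to 50 * 0.1**2 = 0.5
```

Keeping only independent draws inside the plate is what allows the forecaster to treat each time step's variable as conditionally independent; all temporal dependence enters through the deterministic cumulative sum.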
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Sample: 100%|██████████| 2000/2000 [14:44, 2.26it/s, step size=4.47e-02, acc. prob=0.941]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
" mean std median 5.0% 95.0% n_eff r_hat\n", | |
" bias[0] 14.51 0.02 14.51 14.47 14.55 1397.54 1.00\n", | |
" weight[0] -0.04 0.01 -0.04 -0.05 -0.03 1644.47 1.00\n", | |
" weight[1] -0.05 0.01 -0.05 -0.06 -0.04 1449.86 1.00\n", | |
" weight[2] -0.01 0.01 -0.01 -0.02 -0.01 1394.49 1.00\n", | |
" weight[3] -0.02 0.01 -0.02 -0.03 -0.01 1885.88 1.00\n", | |
" weight[4] -0.02 0.01 -0.02 -0.03 -0.01 2369.42 1.00\n", | |
" weight[5] -0.03 0.01 -0.03 -0.04 -0.02 1371.17 1.00\n", | |
" weight[6] -0.01 0.00 -0.01 -0.02 -0.00 1753.90 1.00\n", | |
" weight[7] -0.04 0.01 -0.04 -0.05 -0.03 1948.30 1.00\n", | |
" weight[8] -0.02 0.01 -0.02 -0.03 -0.01 1642.62 1.00\n", | |
" weight[9] -0.03 0.01 -0.03 -0.03 -0.02 1529.05 1.00\n", | |
" weight[10] 0.03 0.01 0.03 0.02 0.03 1447.48 1.00\n", | |
" weight[11] 0.00 0.01 0.00 -0.01 0.01 1756.16 1.00\n", | |
" weight[12] 0.03 0.01 0.03 0.02 0.04 1767.69 1.00\n", | |
" weight[13] 0.01 0.01 0.01 0.00 0.02 2112.39 1.00\n", | |
" weight[14] 0.01 0.01 0.01 0.00 0.02 1763.25 1.00\n", | |
" weight[15] 0.00 0.01 0.00 -0.01 0.01 1121.05 1.00\n", | |
" weight[16] 0.01 0.01 0.01 0.01 0.02 1447.13 1.00\n", | |
" weight[17] -0.01 0.01 -0.01 -0.02 -0.00 1427.15 1.00\n", | |
" weight[18] 0.02 0.01 0.02 0.01 0.03 1717.20 1.00\n", | |
" weight[19] 0.01 0.01 0.01 -0.00 0.02 1336.62 1.00\n", | |
" weight[20] 0.03 0.01 0.03 0.02 0.03 1628.45 1.00\n", | |
" weight[21] 0.01 0.01 0.01 0.00 0.02 2339.68 1.00\n", | |
" weight[22] 0.02 0.01 0.02 0.01 0.02 1490.59 1.00\n", | |
" weight[23] -0.01 0.01 -0.01 -0.01 0.00 2222.28 1.00\n", | |
" weight[24] -0.00 0.01 -0.00 -0.01 0.01 1604.18 1.00\n", | |
" weight[25] -0.01 0.01 -0.01 -0.02 0.00 1914.97 1.00\n", | |
" weight[26] 0.00 0.01 0.00 -0.01 0.01 1952.83 1.00\n", | |
" weight[27] -0.02 0.01 -0.02 -0.03 -0.01 1308.91 1.00\n", | |
" weight[28] -0.02 0.01 -0.02 -0.03 -0.01 1954.75 1.00\n", | |
" weight[29] -0.02 0.01 -0.02 -0.02 -0.01 1768.12 1.00\n", | |
" weight[30] 0.00 0.01 0.00 -0.01 0.01 1879.60 1.00\n", | |
" weight[31] -0.00 0.01 -0.00 -0.01 0.00 2071.55 1.00\n", | |
" weight[32] -0.02 0.01 -0.02 -0.02 -0.01 2316.72 1.00\n", | |
" weight[33] -0.00 0.01 -0.00 -0.01 0.01 1986.63 1.00\n", | |
" weight[34] -0.02 0.01 -0.02 -0.02 -0.01 1692.31 1.00\n", | |
" weight[35] -0.03 0.01 -0.03 -0.04 -0.02 2400.52 1.00\n", | |
" weight[36] -0.02 0.01 -0.02 -0.03 -0.01 2018.29 1.00\n", | |
" weight[37] -0.03 0.01 -0.03 -0.04 -0.02 1550.45 1.00\n", | |
" weight[38] -0.01 0.01 -0.01 -0.02 -0.00 1918.60 1.00\n", | |
" weight[39] -0.01 0.01 -0.01 -0.02 -0.01 1850.17 1.00\n", | |
" weight[40] -0.01 0.01 -0.01 -0.02 0.00 1789.31 1.00\n", | |
" weight[41] -0.02 0.01 -0.02 -0.03 -0.01 2123.44 1.00\n", | |
" weight[42] -0.00 0.01 -0.00 -0.01 0.01 2158.39 1.00\n", | |
" weight[43] -0.01 0.01 -0.01 -0.02 -0.00 1422.64 1.00\n", | |
" weight[44] -0.01 0.01 -0.01 -0.02 -0.01 1863.10 1.00\n", | |
" weight[45] -0.02 0.01 -0.02 -0.03 -0.01 1723.12 1.00\n", | |
" weight[46] 0.00 0.01 0.00 -0.01 0.01 1777.51 1.00\n", | |
" weight[47] -0.02 0.01 -0.02 -0.02 -0.01 2065.29 1.00\n", | |
" weight[48] 0.02 0.01 0.02 0.01 0.03 1472.40 1.00\n", | |
" weight[49] 0.01 0.01 0.01 -0.00 0.02 1637.83 1.00\n", | |
" weight[50] 0.01 0.01 0.01 0.00 0.02 2110.91 1.00\n", | |
" weight[51] 0.00 0.01 0.00 -0.01 0.01 1429.17 1.00\n", | |
" drift_scale[0] 0.01 0.00 0.01 0.00 0.01 56.49 1.00\n", | |
" drift_decentered[0,0] -0.00 0.08 -0.00 -0.13 0.12 2280.56 1.00\n", | |
" drift_decentered[1,0] -0.00 0.08 -0.00 -0.12 0.13 2149.62 1.00\n", | |
" drift_decentered[2,0] -0.00 0.08 -0.00 -0.13 0.13 2960.05 1.00\n", | |
" drift_decentered[3,0] 0.00 0.08 -0.00 -0.12 0.12 1749.26 1.00\n", | |
" drift_decentered[4,0] 0.00 0.08 0.00 -0.12 0.12 1884.39 1.00\n", | |
" drift_decentered[5,0] 0.00 0.07 0.00 -0.11 0.12 1982.12 1.00\n", | |
" drift_decentered[6,0] 0.00 0.08 0.00 -0.14 0.12 2274.62 1.00\n", | |
" drift_decentered[7,0] 0.00 0.08 -0.00 -0.14 0.13 2826.36 1.00\n", | |
" drift_decentered[8,0] 0.01 0.08 0.01 -0.10 0.14 2027.33 1.00\n", | |
" drift_decentered[9,0] 0.01 0.08 0.00 -0.12 0.14 3087.32 1.00\n", | |
" drift_decentered[10,0] 0.01 0.08 0.01 -0.11 0.14 1931.72 1.00\n", | |
" drift_decentered[11,0] 0.01 0.07 0.01 -0.11 0.13 2677.21 1.00\n", | |
" drift_decentered[12,0] 0.01 0.08 0.01 -0.11 0.14 2511.18 1.00\n", | |
" drift_decentered[13,0] 0.01 0.08 0.02 -0.11 0.14 2209.17 1.00\n", | |
" drift_decentered[14,0] 0.01 0.07 0.01 -0.11 0.13 1839.87 1.00\n", | |
" drift_decentered[15,0] 0.01 0.08 0.01 -0.11 0.14 2097.43 1.00\n", | |
" drift_decentered[16,0] 0.01 0.07 0.01 -0.10 0.14 1594.47 1.00\n", | |
" drift_decentered[17,0] 0.02 0.08 0.02 -0.12 0.13 1814.31 1.00\n", | |
" drift_decentered[18,0] 0.01 0.08 0.02 -0.11 0.14 2007.43 1.00\n", | |
" drift_decentered[19,0] 0.01 0.08 0.01 -0.12 0.15 2124.75 1.00\n", | |
" drift_decentered[20,0] 0.01 0.07 0.02 -0.10 0.13 2134.75 1.00\n", | |
" drift_decentered[21,0] 0.01 0.08 0.02 -0.11 0.15 1919.88 1.00\n", | |
" drift_decentered[22,0] 0.02 0.08 0.02 -0.12 0.15 2493.04 1.00\n", | |
" drift_decentered[23,0] 0.02 0.08 0.02 -0.11 0.14 2742.28 1.00\n", | |
" drift_decentered[24,0] 0.02 0.08 0.02 -0.11 0.16 1885.37 1.00\n", | |
" drift_decentered[25,0] 0.02 0.08 0.02 -0.12 0.15 2106.77 1.00\n", | |
" drift_decentered[26,0] 0.02 0.08 0.01 -0.11 0.15 2777.04 1.00\n", | |
" drift_decentered[27,0] 0.01 0.07 0.01 -0.10 0.13 2598.64 1.00\n", | |
" drift_decentered[28,0] 0.01 0.08 0.01 -0.12 0.14 1857.20 1.00\n", | |
" drift_decentered[29,0] 0.01 0.07 0.01 -0.12 0.11 2020.42 1.00\n", | |
" drift_decentered[30,0] 0.01 0.08 0.01 -0.11 0.15 2506.52 1.00\n", | |
" drift_decentered[31,0] 0.01 0.07 0.01 -0.11 0.13 2166.43 1.00\n", | |
" drift_decentered[32,0] 0.01 0.08 0.01 -0.12 0.13 2058.64 1.00\n", | |
" drift_decentered[33,0] 0.01 0.08 0.01 -0.13 0.14 2334.76 1.00\n", | |
" drift_decentered[34,0] 0.01 0.08 0.01 -0.11 0.15 2314.28 1.00\n", | |
" drift_decentered[35,0] 0.01 0.08 0.01 -0.13 0.13 1966.52 1.00\n", | |
" drift_decentered[36,0] 0.01 0.09 0.01 -0.15 0.13 3066.88 1.00\n", | |
" drift_decentered[37,0] 0.00 0.08 0.01 -0.13 0.14 1747.10 1.00\n", | |
" drift_decentered[38,0] 0.01 0.08 0.01 -0.12 0.13 2739.86 1.00\n", | |
" drift_decentered[39,0] 0.01 0.08 0.00 -0.12 0.14 1923.73 1.00\n", | |
" drift_decentered[40,0] 0.01 0.08 0.01 -0.10 0.14 1956.68 1.00\n", | |
" drift_decentered[41,0] 0.02 0.07 0.02 -0.11 0.13 1749.40 1.00\n", | |
" drift_decentered[42,0] 0.03 0.08 0.03 -0.09 0.15 2775.14 1.00\n", | |
" drift_decentered[43,0] 0.02 0.08 0.02 -0.12 0.14 2532.88 1.00\n", | |
" drift_decentered[44,0] 0.02 0.08 0.02 -0.11 0.14 2634.06 1.00\n", | |
" drift_decentered[45,0] 0.03 0.08 0.03 -0.09 0.15 2358.67 1.00\n", | |
" drift_decentered[46,0] 0.03 0.08 0.03 -0.10 0.15 2776.51 1.00\n", | |
" drift_decentered[47,0] 0.03 0.08 0.03 -0.09 0.15 2195.39 1.00\n", | |
" drift_decentered[48,0] 0.03 0.07 0.03 -0.09 0.15 1710.29 1.00\n", | |
" drift_decentered[49,0] 0.03 0.08 0.03 -0.09 0.16 2185.53 1.00\n", | |
" drift_decentered[50,0] 0.03 0.07 0.03 -0.09 0.15 1775.65 1.00\n", | |
" drift_decentered[51,0] 0.03 0.08 0.03 -0.10 0.15 2005.32 1.00\n", | |
" drift_decentered[52,0] 0.02 0.08 0.02 -0.10 0.15 2555.93 1.00\n", | |
" drift_decentered[53,0] 0.01 0.08 0.01 -0.12 0.14 2080.88 1.00\n", | |
" drift_decentered[54,0] 0.01 0.08 0.02 -0.12 0.14 1830.20 1.00\n", | |
" drift_decentered[55,0] 0.01 0.08 0.02 -0.12 0.14 1998.09 1.00\n", | |
" drift_decentered[56,0] 0.01 0.08 0.01 -0.10 0.14 1832.68 1.00\n", | |
" drift_decentered[57,0] 0.02 0.08 0.02 -0.11 0.15 2399.11 1.00\n",
" drift_decentered[58,0] 0.02 0.07 0.02 -0.10 0.14 2063.95 1.00\n",
" drift_decentered[59,0] 0.01 0.07 0.01 -0.11 0.13 1999.16 1.00\n",
" drift_decentered[60,0] 0.01 0.08 0.01 -0.11 0.14 2787.60 1.00\n",
" drift_decentered[61,0] 0.01 0.08 0.01 -0.11 0.15 2163.29 1.00\n",
" drift_decentered[62,0] 0.01 0.08 0.01 -0.12 0.14 1962.54 1.00\n",
" drift_decentered[63,0] 0.01 0.08 0.01 -0.11 0.14 1247.90 1.00\n",
" drift_decentered[64,0] 0.01 0.08 0.01 -0.11 0.14 2294.95 1.00\n",
" drift_decentered[65,0] 0.02 0.08 0.01 -0.11 0.15 2355.35 1.00\n",
" drift_decentered[66,0] 0.02 0.08 0.01 -0.10 0.15 1710.11 1.00\n",
" drift_decentered[67,0] 0.02 0.08 0.02 -0.12 0.15 2316.11 1.00\n",
" drift_decentered[68,0] 0.02 0.08 0.02 -0.12 0.15 1887.58 1.00\n",
" drift_decentered[69,0] 0.02 0.07 0.02 -0.10 0.13 2545.48 1.00\n",
" drift_decentered[70,0] 0.02 0.08 0.02 -0.14 0.14 1625.43 1.00\n",
" drift_decentered[71,0] 0.02 0.08 0.02 -0.10 0.14 2979.34 1.00\n",
" drift_decentered[72,0] 0.02 0.08 0.02 -0.10 0.15 2200.54 1.00\n",
" drift_decentered[73,0] 0.02 0.07 0.02 -0.11 0.14 2330.59 1.00\n",
" drift_decentered[74,0] 0.02 0.08 0.02 -0.11 0.16 2214.47 1.00\n",
" drift_decentered[75,0] 0.02 0.08 0.02 -0.11 0.16 1867.87 1.00\n",
" drift_decentered[76,0] 0.03 0.08 0.03 -0.11 0.15 2406.28 1.00\n",
" drift_decentered[77,0] 0.03 0.08 0.03 -0.09 0.15 1838.78 1.00\n",
" drift_decentered[78,0] 0.02 0.08 0.02 -0.12 0.13 2009.60 1.00\n",
" drift_decentered[79,0] 0.02 0.08 0.01 -0.11 0.14 2249.17 1.00\n",
" drift_decentered[80,0] 0.02 0.08 0.02 -0.12 0.13 2011.00 1.00\n",
" drift_decentered[81,0] 0.02 0.08 0.01 -0.11 0.14 1394.16 1.00\n",
" drift_decentered[82,0] 0.02 0.08 0.02 -0.09 0.15 2975.60 1.00\n",
" drift_decentered[83,0] 0.02 0.08 0.02 -0.10 0.16 1938.97 1.00\n",
" drift_decentered[84,0] 0.02 0.08 0.02 -0.11 0.14 2355.04 1.00\n",
" drift_decentered[85,0] 0.02 0.08 0.02 -0.11 0.14 2214.49 1.00\n",
" drift_decentered[86,0] 0.02 0.08 0.02 -0.11 0.14 2057.82 1.00\n",
" drift_decentered[87,0] 0.02 0.08 0.02 -0.10 0.16 2202.00 1.00\n",
" drift_decentered[88,0] 0.02 0.08 0.02 -0.10 0.15 2107.02 1.00\n",
" drift_decentered[89,0] 0.02 0.08 0.02 -0.13 0.14 2331.15 1.00\n",
" drift_decentered[90,0] 0.02 0.08 0.02 -0.11 0.14 2374.64 1.00\n",
" drift_decentered[91,0] 0.02 0.08 0.02 -0.10 0.15 1817.40 1.00\n",
" drift_decentered[92,0] 0.02 0.08 0.02 -0.12 0.14 2255.18 1.00\n",
" drift_decentered[93,0] 0.01 0.08 0.01 -0.10 0.15 2872.33 1.00\n",
" drift_decentered[94,0] 0.01 0.08 0.01 -0.12 0.12 2547.74 1.00\n",
" drift_decentered[95,0] -0.00 0.08 -0.00 -0.12 0.13 1736.13 1.00\n",
" drift_decentered[96,0] -0.01 0.08 -0.01 -0.13 0.12 2246.23 1.00\n",
" drift_decentered[97,0] 0.00 0.07 0.00 -0.11 0.12 1970.27 1.00\n",
" drift_decentered[98,0] 0.00 0.08 0.00 -0.11 0.14 2320.80 1.00\n",
" drift_decentered[99,0] 0.01 0.08 0.01 -0.12 0.14 2322.81 1.00\n",
"drift_decentered[100,0] -0.01 0.07 -0.00 -0.13 0.11 1798.63 1.00\n",
"drift_decentered[101,0] -0.00 0.07 -0.00 -0.11 0.13 1964.53 1.00\n",
"drift_decentered[102,0] -0.01 0.07 -0.00 -0.13 0.11 1402.21 1.00\n",
"drift_decentered[103,0] -0.00 0.08 -0.01 -0.14 0.11 3041.40 1.00\n",
"drift_decentered[104,0] -0.00 0.08 0.00 -0.13 0.12 2048.45 1.00\n",
"drift_decentered[105,0] -0.00 0.08 -0.00 -0.13 0.13 2501.52 1.00\n",
"drift_decentered[106,0] 0.00 0.08 0.00 -0.12 0.12 2223.53 1.00\n",
"drift_decentered[107,0] -0.01 0.08 -0.01 -0.13 0.11 2175.94 1.00\n",
"drift_decentered[108,0] -0.00 0.07 -0.00 -0.11 0.14 2254.59 1.00\n",
"drift_decentered[109,0] 0.00 0.08 0.00 -0.13 0.12 1999.52 1.00\n",
"drift_decentered[110,0] 0.00 0.07 -0.00 -0.12 0.12 1656.13 1.00\n",
"drift_decentered[111,0] 0.00 0.08 0.00 -0.14 0.12 2371.85 1.00\n",
"drift_decentered[112,0] -0.00 0.08 0.00 -0.14 0.11 2022.37 1.00\n",
"drift_decentered[113,0] -0.00 0.07 -0.01 -0.13 0.11 1591.65 1.00\n",
"drift_decentered[114,0] 0.00 0.08 0.00 -0.12 0.14 2397.95 1.00\n",
"drift_decentered[115,0] -0.01 0.08 -0.01 -0.12 0.13 1534.68 1.00\n",
"drift_decentered[116,0] -0.01 0.07 -0.01 -0.14 0.11 2088.19 1.00\n",
"drift_decentered[117,0] -0.01 0.08 -0.01 -0.16 0.10 1669.96 1.00\n",
"drift_decentered[118,0] -0.01 0.08 -0.01 -0.13 0.11 2734.63 1.00\n",
"drift_decentered[119,0] -0.02 0.08 -0.02 -0.13 0.11 2288.01 1.00\n",
"drift_decentered[120,0] -0.02 0.08 -0.02 -0.16 0.10 2010.31 1.00\n",
"drift_decentered[121,0] -0.02 0.08 -0.02 -0.15 0.12 1830.93 1.00\n",
"drift_decentered[122,0] -0.03 0.08 -0.03 -0.17 0.10 1942.70 1.00\n",
"drift_decentered[123,0] -0.03 0.08 -0.02 -0.16 0.12 1710.74 1.00\n",
"drift_decentered[124,0] -0.03 0.08 -0.03 -0.14 0.10 1902.38 1.00\n",
"drift_decentered[125,0] -0.03 0.08 -0.03 -0.15 0.10 1834.86 1.00\n",
"drift_decentered[126,0] -0.03 0.08 -0.03 -0.16 0.10 1540.11 1.00\n",
"drift_decentered[127,0] -0.04 0.08 -0.04 -0.17 0.10 1585.34 1.00\n",
"drift_decentered[128,0] -0.04 0.08 -0.04 -0.17 0.09 1134.56 1.00\n",
"drift_decentered[129,0] -0.05 0.08 -0.05 -0.17 0.10 1015.88 1.00\n",
"drift_decentered[130,0] -0.04 0.08 -0.05 -0.18 0.08 1529.87 1.00\n",
"drift_decentered[131,0] 0.04 0.08 0.04 -0.09 0.16 1440.67 1.00\n",
"drift_decentered[132,0] 0.03 0.08 0.03 -0.10 0.15 1694.88 1.00\n",
"drift_decentered[133,0] 0.04 0.08 0.03 -0.08 0.17 2123.96 1.00\n",
"drift_decentered[134,0] 0.03 0.08 0.03 -0.10 0.16 1948.59 1.00\n",
"drift_decentered[135,0] 0.03 0.08 0.03 -0.10 0.15 1808.72 1.00\n",
"drift_decentered[136,0] 0.03 0.08 0.03 -0.09 0.16 1794.99 1.00\n",
"drift_decentered[137,0] 0.03 0.07 0.03 -0.08 0.15 1624.73 1.00\n",
"drift_decentered[138,0] 0.03 0.08 0.03 -0.10 0.15 2106.88 1.00\n",
"drift_decentered[139,0] 0.02 0.08 0.02 -0.12 0.14 2764.65 1.00\n",
"drift_decentered[140,0] 0.01 0.08 0.01 -0.13 0.12 2014.93 1.00\n",
"drift_decentered[141,0] -0.00 0.08 -0.00 -0.12 0.12 2051.49 1.00\n",
"drift_decentered[142,0] 0.00 0.08 -0.00 -0.13 0.12 2140.06 1.00\n",
"drift_decentered[143,0] -0.01 0.08 -0.00 -0.13 0.11 1608.82 1.00\n",
"drift_decentered[144,0] -0.01 0.08 -0.01 -0.14 0.12 2735.76 1.00\n",
"drift_decentered[145,0] -0.01 0.08 -0.01 -0.14 0.11 1685.73 1.00\n",
"drift_decentered[146,0] 0.00 0.08 0.00 -0.12 0.13 2248.70 1.00\n",
"drift_decentered[147,0] 0.04 0.08 0.04 -0.10 0.17 2406.98 1.00\n",
"drift_decentered[148,0] 0.04 0.07 0.04 -0.08 0.16 1801.11 1.00\n",
"drift_decentered[149,0] 0.04 0.08 0.04 -0.08 0.16 2120.64 1.00\n",
"drift_decentered[150,0] 0.04 0.08 0.04 -0.10 0.16 2062.52 1.00\n",
"drift_decentered[151,0] 0.02 0.08 0.02 -0.11 0.15 2515.49 1.00\n",
"drift_decentered[152,0] 0.03 0.08 0.03 -0.10 0.16 2155.51 1.00\n",
"drift_decentered[153,0] 0.03 0.08 0.03 -0.11 0.16 1980.07 1.00\n",
"drift_decentered[154,0] 0.03 0.08 0.03 -0.10 0.15 1963.91 1.00\n",
"drift_decentered[155,0] 0.03 0.08 0.03 -0.09 0.16 2677.89 1.00\n",
"drift_decentered[156,0] 0.02 0.07 0.02 -0.10 0.14 2036.48 1.00\n",
"drift_decentered[157,0] 0.02 0.07 0.02 -0.09 0.14 2146.23 1.00\n",
"drift_decentered[158,0] 0.02 0.07 0.02 -0.09 0.13 2081.89 1.00\n",
"drift_decentered[159,0] 0.02 0.08 0.02 -0.11 0.14 1831.78 1.00\n",
"drift_decentered[160,0] 0.02 0.08 0.02 -0.09 0.16 2408.98 1.00\n",
"drift_decentered[161,0] 0.02 0.08 0.02 -0.10 0.14 2085.32 1.00\n",
"drift_decentered[162,0] 0.02 0.08 0.02 -0.09 0.14 2162.36 1.00\n",
"drift_decentered[163,0] 0.03 0.08 0.02 -0.12 0.15 2566.22 1.00\n",
"drift_decentered[164,0] 0.03 0.07 0.03 -0.10 0.14 2102.31 1.00\n",
"drift_decentered[165,0] 0.02 0.08 0.02 -0.11 0.15 2307.18 1.00\n",
"drift_decentered[166,0] 0.02 0.08 0.02 -0.12 0.14 1884.98 1.00\n",
"drift_decentered[167,0] 0.02 0.08 0.03 -0.10 0.15 2716.65 1.00\n",
"drift_decentered[168,0] 0.02 0.07 0.02 -0.10 0.14 2696.14 1.00\n",
"drift_decentered[169,0] 0.02 0.08 0.02 -0.09 0.16 1973.99 1.00\n",
"drift_decentered[170,0] 0.02 0.08 0.02 -0.10 0.16 2496.70 1.00\n",
"drift_decentered[171,0] 0.02 0.08 0.02 -0.12 0.14 2022.09 1.00\n",
"drift_decentered[172,0] 0.02 0.08 0.02 -0.10 0.16 2496.27 1.00\n",
"drift_decentered[173,0] 0.02 0.08 0.03 -0.11 0.14 1909.92 1.00\n",
"drift_decentered[174,0] 0.03 0.08 0.02 -0.09 0.15 1754.82 1.00\n",
"drift_decentered[175,0] 0.03 0.08 0.03 -0.09 0.16 2298.14 1.00\n",
"drift_decentered[176,0] 0.03 0.07 0.02 -0.10 0.15 2009.08 1.00\n",
"drift_decentered[177,0] 0.03 0.08 0.03 -0.11 0.14 2085.97 1.00\n",
"drift_decentered[178,0] 0.03 0.08 0.03 -0.12 0.14 2336.24 1.00\n",
"drift_decentered[179,0] 0.02 0.08 0.02 -0.10 0.16 2446.19 1.00\n",
"drift_decentered[180,0] 0.02 0.07 0.02 -0.09 0.15 2112.43 1.00\n",
"drift_decentered[181,0] 0.02 0.07 0.02 -0.09 0.15 1977.64 1.00\n",
"drift_decentered[182,0] 0.03 0.07 0.03 -0.07 0.15 1417.35 1.00\n",
"drift_decentered[183,0] 0.02 0.08 0.02 -0.11 0.15 2305.05 1.00\n",
"drift_decentered[184,0] 0.01 0.08 0.01 -0.11 0.15 2187.09 1.00\n",
"drift_decentered[185,0] 0.01 0.08 0.01 -0.11 0.14 1979.78 1.00\n",
"drift_decentered[186,0] 0.01 0.08 0.01 -0.12 0.12 1608.55 1.00\n",
"drift_decentered[187,0] 0.01 0.08 0.01 -0.10 0.13 1926.19 1.00\n",
"drift_decentered[188,0] 0.01 0.08 0.01 -0.12 0.14 2417.16 1.00\n",
"drift_decentered[189,0] 0.01 0.08 0.01 -0.11 0.13 2609.61 1.00\n",
"drift_decentered[190,0] 0.01 0.07 0.01 -0.11 0.13 2745.14 1.00\n",
"drift_decentered[191,0] 0.01 0.08 0.01 -0.11 0.13 2015.62 1.00\n",
"drift_decentered[192,0] 0.02 0.08 0.02 -0.10 0.15 2463.93 1.00\n",
"drift_decentered[193,0] 0.01 0.07 0.01 -0.12 0.12 1979.96 1.00\n",
"drift_decentered[194,0] 0.01 0.08 0.02 -0.12 0.13 2405.08 1.00\n",
"drift_decentered[195,0] 0.01 0.08 0.01 -0.12 0.14 1778.21 1.00\n",
"drift_decentered[196,0] 0.02 0.09 0.02 -0.14 0.15 1902.61 1.00\n",
"drift_decentered[197,0] 0.02 0.08 0.01 -0.10 0.16 1892.38 1.00\n",
"drift_decentered[198,0] 0.01 0.08 0.01 -0.13 0.12 1858.10 1.00\n",
"drift_decentered[199,0] -0.00 0.07 -0.00 -0.12 0.12 2446.63 1.00\n",
"drift_decentered[200,0] 0.00 0.08 0.00 -0.13 0.13 1890.72 1.00\n",
"drift_decentered[201,0] -0.00 0.08 -0.00 -0.14 0.11 2212.48 1.00\n",
"drift_decentered[202,0] 0.00 0.08 -0.00 -0.13 0.15 2514.32 1.00\n",
"drift_decentered[203,0] -0.01 0.08 -0.01 -0.14 0.11 1826.21 1.00\n",
"drift_decentered[204,0] -0.01 0.08 -0.01 -0.15 0.10 1877.97 1.00\n",
"drift_decentered[205,0] 0.00 0.08 0.00 -0.13 0.13 1970.74 1.00\n",
"drift_decentered[206,0] 0.00 0.08 -0.00 -0.11 0.14 2043.26 1.00\n",
"drift_decentered[207,0] 0.00 0.08 0.00 -0.12 0.13 1952.14 1.00\n",
"drift_decentered[208,0] 0.01 0.07 0.01 -0.12 0.13 2008.58 1.00\n",
"drift_decentered[209,0] 0.01 0.07 0.00 -0.10 0.13 3224.89 1.00\n",
"drift_decentered[210,0] 0.00 0.08 0.01 -0.12 0.13 2929.19 1.00\n",
"drift_decentered[211,0] 0.00 0.07 0.00 -0.12 0.12 1722.44 1.00\n",
"drift_decentered[212,0] 0.00 0.09 0.00 -0.14 0.14 2203.02 1.00\n",
"drift_decentered[213,0] 0.01 0.08 0.01 -0.12 0.13 2127.38 1.00\n",
"drift_decentered[214,0] 0.01 0.07 0.01 -0.11 0.13 1830.46 1.00\n",
"drift_decentered[215,0] 0.01 0.08 0.01 -0.10 0.14 1951.80 1.00\n",
"drift_decentered[216,0] 0.01 0.08 0.01 -0.11 0.14 1853.59 1.00\n",
"drift_decentered[217,0] 0.01 0.08 0.01 -0.12 0.13 1577.28 1.00\n",
"drift_decentered[218,0] 0.01 0.07 0.01 -0.12 0.12 2765.80 1.00\n",
"drift_decentered[219,0] 0.01 0.07 0.01 -0.12 0.12 2129.92 1.00\n",
"drift_decentered[220,0] 0.01 0.08 0.01 -0.11 0.14 2066.79 1.00\n",
"drift_decentered[221,0] 0.00 0.08 0.00 -0.13 0.13 2121.65 1.00\n",
"drift_decentered[222,0] 0.00 0.08 0.00 -0.12 0.13 1740.52 1.00\n",
"drift_decentered[223,0] 0.01 0.08 0.01 -0.11 0.13 2533.83 1.00\n",
"drift_decentered[224,0] 0.01 0.08 0.00 -0.11 0.16 1874.91 1.00\n",
"drift_decentered[225,0] 0.00 0.08 0.01 -0.13 0.13 2108.99 1.00\n",
"drift_decentered[226,0] 0.01 0.08 0.00 -0.12 0.14 2191.95 1.00\n",
"drift_decentered[227,0] 0.01 0.08 0.01 -0.14 0.13 2820.79 1.00\n",
"drift_decentered[228,0] 0.01 0.08 0.01 -0.12 0.13 1968.55 1.00\n",
"drift_decentered[229,0] 0.01 0.07 0.01 -0.12 0.13 2098.27 1.00\n",
"drift_decentered[230,0] 0.02 0.08 0.02 -0.11 0.14 2287.92 1.00\n",
"drift_decentered[231,0] 0.01 0.08 0.01 -0.13 0.15 2384.30 1.00\n",
"drift_decentered[232,0] 0.01 0.08 0.01 -0.13 0.14 1855.01 1.00\n",
"drift_decentered[233,0] 0.01 0.07 0.01 -0.11 0.13 2371.62 1.00\n",
"drift_decentered[234,0] 0.01 0.08 0.01 -0.12 0.14 1917.48 1.00\n",
"drift_decentered[235,0] 0.00 0.08 0.00 -0.12 0.11 2254.28 1.00\n",
"drift_decentered[236,0] -0.01 0.08 -0.01 -0.13 0.12 1871.69 1.00\n",
"drift_decentered[237,0] -0.00 0.08 -0.00 -0.13 0.11 1856.11 1.00\n",
"drift_decentered[238,0] -0.01 0.08 -0.01 -0.14 0.12 2354.03 1.00\n",
"drift_decentered[239,0] -0.00 0.08 -0.00 -0.13 0.12 1970.17 1.00\n",
"drift_decentered[240,0] -0.00 0.07 0.00 -0.11 0.13 1893.61 1.00\n",
"drift_decentered[241,0] 0.00 0.08 0.00 -0.11 0.14 2210.72 1.00\n",
"drift_decentered[242,0] 0.00 0.08 0.00 -0.12 0.14 2135.80 1.00\n",
"drift_decentered[243,0] 0.00 0.08 0.00 -0.13 0.12 2554.41 1.00\n",
"drift_decentered[244,0] -0.00 0.08 -0.00 -0.13 0.13 2055.71 1.00\n",
"drift_decentered[245,0] 0.01 0.08 0.01 -0.12 0.13 2078.37 1.00\n",
"drift_decentered[246,0] 0.01 0.08 0.01 -0.12 0.14 2090.98 1.00\n",
"drift_decentered[247,0] 0.01 0.07 0.01 -0.13 0.11 2075.33 1.00\n",
"drift_decentered[248,0] 0.01 0.08 0.01 -0.11 0.13 2967.24 1.00\n",
"drift_decentered[249,0] 0.01 0.08 0.01 -0.11 0.14 1886.85 1.00\n",
"drift_decentered[250,0] 0.01 0.08 0.01 -0.12 0.13 1857.60 1.00\n",
"drift_decentered[251,0] 0.01 0.07 0.00 -0.12 0.12 1676.37 1.00\n",
"drift_decentered[252,0] 0.00 0.08 0.00 -0.13 0.14 2476.50 1.00\n",
"drift_decentered[253,0] 0.00 0.08 0.00 -0.12 0.14 2244.81 1.00\n",
"drift_decentered[254,0] 0.01 0.07 0.00 -0.11 0.12 3151.63 1.00\n",
"drift_decentered[255,0] -0.00 0.08 -0.00 -0.13 0.13 2650.73 1.00\n",
"drift_decentered[256,0] -0.00 0.08 -0.00 -0.12 0.13 2059.76 1.00\n",
"drift_decentered[257,0] 0.00 0.07 0.01 -0.11 0.13 2177.54 1.00\n",
"drift_decentered[258,0] -0.00 0.07 -0.00 -0.12 0.11 1819.86 1.00\n",
"drift_decentered[259,0] 0.01 0.08 0.00 -0.13 0.13 1870.92 1.00\n",
"drift_decentered[260,0] 0.01 0.08 0.01 -0.11 0.13 1972.48 1.00\n",
"drift_decentered[261,0] 0.01 0.08 0.00 -0.12 0.13 2648.05 1.00\n",
"drift_decentered[262,0] 0.01 0.08 0.01 -0.12 0.14 2107.93 1.00\n",
"drift_decentered[263,0] 0.01 0.08 0.01 -0.12 0.13 2010.31 1.00\n",
"drift_decentered[264,0] 0.01 0.08 0.01 -0.12 0.15 2289.75 1.00\n",
"drift_decentered[265,0] 0.00 0.08 0.00 -0.12 0.13 1988.17 1.00\n",
"drift_decentered[266,0] -0.00 0.08 -0.01 -0.14 0.11 2188.06 1.00\n",
"drift_decentered[267,0] -0.01 0.07 -0.01 -0.13 0.11 1993.18 1.00\n",
"drift_decentered[268,0] -0.01 0.08 -0.01 -0.13 0.11 2204.61 1.00\n",
"drift_decentered[269,0] -0.01 0.08 -0.01 -0.14 0.11 2230.24 1.00\n",
"drift_decentered[270,0] -0.01 0.08 -0.01 -0.13 0.11 1769.90 1.00\n",
"drift_decentered[271,0] -0.01 0.08 -0.01 -0.15 0.11 2336.43 1.00\n",
"drift_decentered[272,0] -0.01 0.07 -0.01 -0.12 0.10 2168.20 1.00\n",
"drift_decentered[273,0] -0.01 0.07 -0.01 -0.13 0.12 3102.44 1.00\n",
"drift_decentered[274,0] -0.00 0.08 -0.00 -0.14 0.13 2418.67 1.00\n",
"drift_decentered[275,0] -0.01 0.08 -0.01 -0.14 0.13 2161.31 1.00\n",
"drift_decentered[276,0] -0.01 0.08 -0.01 -0.12 0.13 1862.83 1.00\n",
"drift_decentered[277,0] -0.01 0.07 -0.01 -0.13 0.11 2816.21 1.00\n",
"drift_decentered[278,0] -0.01 0.08 -0.01 -0.15 0.11 3051.01 1.00\n",
"drift_decentered[279,0] -0.01 0.07 -0.01 -0.13 0.10 2147.25 1.00\n",
"drift_decentered[280,0] -0.01 0.08 -0.01 -0.14 0.11 2749.43 1.00\n",
"drift_decentered[281,0] -0.01 0.08 -0.01 -0.14 0.11 2114.90 1.00\n",
"drift_decentered[282,0] -0.01 0.07 -0.01 -0.14 0.10 2181.06 1.00\n",
"drift_decentered[283,0] -0.01 0.08 -0.01 -0.13 0.12 2433.70 1.00\n",
"drift_decentered[284,0] -0.00 0.08 0.00 -0.13 0.13 2611.32 1.00\n",
"drift_decentered[285,0] -0.00 0.08 -0.00 -0.12 0.13 2124.17 1.00\n",
"drift_decentered[286,0] -0.00 0.08 -0.01 -0.14 0.13 2977.05 1.00\n",
"drift_decentered[287,0] -0.01 0.08 -0.01 -0.14 0.11 1762.31 1.00\n",
"drift_decentered[288,0] -0.02 0.08 -0.01 -0.15 0.12 3133.33 1.00\n",
"drift_decentered[289,0] -0.01 0.08 -0.02 -0.12 0.12 1887.68 1.00\n",
"drift_decentered[290,0] -0.01 0.08 -0.01 -0.14 0.11 2091.11 1.00\n",
"drift_decentered[291,0] -0.02 0.08 -0.02 -0.14 0.12 2040.60 1.00\n",
"drift_decentered[292,0] -0.01 0.08 -0.01 -0.14 0.11 1963.15 1.00\n",
"drift_decentered[293,0] -0.01 0.08 -0.01 -0.13 0.12 1821.85 1.00\n",
"drift_decentered[294,0] -0.01 0.08 -0.01 -0.15 0.11 2315.90 1.00\n",
"drift_decentered[295,0] -0.01 0.07 -0.01 -0.13 0.11 1630.89 1.00\n",
"drift_decentered[296,0] -0.01 0.08 -0.00 -0.12 0.13 1629.48 1.00\n",
"drift_decentered[297,0] -0.01 0.08 -0.01 -0.14 0.12 2626.41 1.00\n",
"drift_decentered[298,0] -0.00 0.08 -0.00 -0.13 0.14 2003.47 1.00\n",
"drift_decentered[299,0] -0.00 0.08 -0.01 -0.12 0.13 1785.08 1.00\n",
"drift_decentered[300,0] 0.00 0.08 0.00 -0.14 0.11 1992.51 1.00\n",
"drift_decentered[301,0] -0.00 0.07 -0.00 -0.13 0.12 2109.78 1.00\n",
"drift_decentered[302,0] -0.00 0.07 -0.01 -0.13 0.12 2712.35 1.00\n",
"drift_decentered[303,0] -0.01 0.07 -0.01 -0.13 0.11 1653.44 1.00\n",
"drift_decentered[304,0] -0.01 0.08 -0.01 -0.14 0.13 3004.20 1.00\n",
"drift_decentered[305,0] -0.01 0.08 -0.01 -0.16 0.11 2955.71 1.00\n",
"drift_decentered[306,0] -0.01 0.08 -0.01 -0.13 0.13 2201.77 1.00\n",
"drift_decentered[307,0] -0.01 0.08 -0.01 -0.14 0.11 2545.16 1.00\n",
"drift_decentered[308,0] -0.01 0.08 -0.01 -0.13 0.13 2125.86 1.00\n",
"drift_decentered[309,0] -0.01 0.08 -0.01 -0.13 0.11 2319.22 1.00\n",
"drift_decentered[310,0] -0.01 0.08 -0.01 -0.14 0.11 2804.59 1.00\n",
"drift_decentered[311,0] -0.01 0.08 -0.01 -0.13 0.12 1680.62 1.00\n",
"drift_decentered[312,0] -0.02 0.07 -0.02 -0.15 0.09 2111.40 1.00\n",
"drift_decentered[313,0] -0.02 0.08 -0.02 -0.15 0.11 1990.56 1.00\n",
"drift_decentered[314,0] -0.01 0.08 -0.01 -0.15 0.12 2197.25 1.00\n",
"drift_decentered[315,0] -0.01 0.07 -0.01 -0.12 0.11 2203.74 1.00\n",
"drift_decentered[316,0] -0.01 0.08 -0.01 -0.14 0.12 2563.04 1.00\n",
"drift_decentered[317,0] -0.01 0.08 -0.01 -0.13 0.13 2517.32 1.00\n",
"drift_decentered[318,0] -0.01 0.08 -0.01 -0.13 0.12 2542.85 1.00\n",
"drift_decentered[319,0] -0.01 0.08 -0.01 -0.14 0.12 3152.07 1.00\n",
"drift_decentered[320,0] -0.01 0.08 -0.02 -0.14 0.13 2068.76 1.00\n",
"drift_decentered[321,0] -0.01 0.07 -0.01 -0.13 0.11 1590.10 1.00\n",
"drift_decentered[322,0] -0.01 0.08 -0.01 -0.13 0.14 1903.17 1.00\n",
"drift_decentered[323,0] -0.01 0.08 -0.01 -0.14 0.11 1692.83 1.00\n",
"drift_decentered[324,0] -0.01 0.08 -0.01 -0.13 0.11 1619.67 1.00\n",
"drift_decentered[325,0] -0.01 0.08 -0.01 -0.14 0.13 2083.94 1.00\n",
"drift_decentered[326,0] -0.01 0.08 -0.01 -0.12 0.14 1755.35 1.00\n",
"drift_decentered[327,0] -0.01 0.07 -0.01 -0.15 0.10 2749.83 1.00\n",
"drift_decentered[328,0] -0.00 0.08 -0.00 -0.14 0.12 2462.57 1.00\n",
"drift_decentered[329,0] -0.01 0.07 -0.01 -0.12 0.12 2129.79 1.00\n",
"drift_decentered[330,0] -0.01 0.07 -0.01 -0.14 0.11 1991.33 1.00\n",
"drift_decentered[331,0] -0.01 0.08 -0.01 -0.13 0.14 2144.75 1.00\n",
"drift_decentered[332,0] -0.00 0.08 -0.01 -0.15 0.12 1747.65 1.00\n",
"drift_decentered[333,0] -0.00 0.08 -0.01 -0.13 0.12 2256.85 1.00\n",
"drift_decentered[334,0] -0.01 0.07 -0.01 -0.13 0.11 2032.04 1.00\n",
"drift_decentered[335,0] -0.00 0.07 -0.00 -0.14 0.10 1893.35 1.00\n",
"drift_decentered[336,0] -0.00 0.08 -0.00 -0.13 0.12 2272.23 1.00\n",
"drift_decentered[337,0] -0.00 0.08 -0.00 -0.15 0.12 1864.94 1.00\n",
"drift_decentered[338,0] -0.00 0.08 0.00 -0.12 0.13 2524.94 1.00\n",
"drift_decentered[339,0] -0.01 0.08 -0.01 -0.13 0.12 2758.25 1.00\n",
"drift_decentered[340,0] -0.01 0.08 -0.01 -0.14 0.14 2419.34 1.00\n",
"drift_decentered[341,0] -0.01 0.08 -0.01 -0.15 0.10 1592.87 1.00\n",
"drift_decentered[342,0] -0.01 0.08 -0.01 -0.12 0.13 1671.89 1.00\n",
"drift_decentered[343,0] -0.01 0.08 -0.01 -0.14 0.11 2073.63 1.00\n",
"drift_decentered[344,0] -0.01 0.08 -0.01 -0.15 0.10 2001.33 1.00\n",
"drift_decentered[345,0] -0.01 0.08 -0.01 -0.14 0.11 2059.34 1.00\n",
"drift_decentered[346,0] -0.01 0.08 -0.01 -0.14 0.11 1953.27 1.00\n",
"drift_decentered[347,0] -0.01 0.08 -0.01 -0.16 0.12 2032.89 1.00\n",
"drift_decentered[348,0] -0.01 0.08 -0.01 -0.13 0.12 1912.42 1.00\n",
"drift_decentered[349,0] -0.00 0.08 -0.00 -0.14 0.10 2352.34 1.00\n",
"drift_decentered[350,0] -0.00 0.08 -0.00 -0.12 0.12 2170.38 1.00\n",
"drift_decentered[351,0] -0.00 0.08 -0.01 -0.12 0.13 2234.80 1.00\n",
"drift_decentered[352,0] -0.00 0.08 -0.00 -0.14 0.13 2765.89 1.00\n",
"drift_decentered[353,0] 0.00 0.08 0.00 -0.12 0.13 1772.89 1.00\n",
"drift_decentered[354,0] 0.00 0.08 0.00 -0.13 0.14 2685.68 1.00\n",
"drift_decentered[355,0] -0.00 0.07 -0.01 -0.11 0.12 1998.39 1.00\n",
"drift_decentered[356,0] -0.01 0.08 -0.01 -0.13 0.12 2487.86 1.00\n",
"drift_decentered[357,0] -0.00 0.08 -0.00 -0.15 0.11 2307.93 1.00\n",
"drift_decentered[358,0] -0.01 0.07 -0.01 -0.12 0.12 2238.92 1.00\n",
"drift_decentered[359,0] -0.01 0.07 -0.01 -0.12 0.10 3087.89 1.00\n",
"drift_decentered[360,0] -0.00 0.07 -0.00 -0.13 0.11 2121.82 1.00\n",
"drift_decentered[361,0] -0.01 0.08 -0.01 -0.14 0.13 2218.94 1.00\n",
"drift_decentered[362,0] -0.01 0.08 -0.01 -0.15 0.10 2701.36 1.00\n",
"drift_decentered[363,0] -0.01 0.07 -0.01 -0.14 0.10 1882.65 1.00\n",
"drift_decentered[364,0] -0.02 0.07 -0.02 -0.15 0.10 1723.27 1.00\n",
"drift_decentered[365,0] -0.01 0.08 -0.01 -0.14 0.12 3001.26 1.00\n",
"drift_decentered[366,0] -0.00 0.08 -0.00 -0.13 0.12 1785.72 1.00\n",
"drift_decentered[367,0] -0.00 0.08 -0.00 -0.14 0.12 1898.94 1.00\n",
"drift_decentered[368,0] -0.00 0.08 -0.01 -0.14 0.12 2161.14 1.00\n",
"drift_decentered[369,0] -0.00 0.07 -0.00 -0.13 0.10 1931.30 1.00\n",
"drift_decentered[370,0] -0.00 0.08 -0.00 -0.14 0.12 2246.09 1.00\n",
"drift_decentered[371,0] 0.00 0.08 -0.00 -0.13 0.13 2213.35 1.00\n",
"drift_decentered[372,0] -0.00 0.07 -0.00 -0.12 0.12 1852.47 1.00\n",
"drift_decentered[373,0] 0.00 0.08 0.00 -0.13 0.12 2500.73 1.00\n",
"drift_decentered[374,0] 0.00 0.07 -0.00 -0.13 0.11 1900.90 1.00\n",
"drift_decentered[375,0] 0.00 0.08 0.00 -0.13 0.12 2284.26 1.00\n",
"drift_decentered[376,0] 0.00 0.08 0.00 -0.13 0.12 2582.45 1.00\n",
"drift_decentered[377,0] 0.01 0.08 0.00 -0.12 0.15 2693.59 1.00\n",
"drift_decentered[378,0] 0.00 0.08 0.00 -0.12 0.13 1882.20 1.00\n",
"drift_decentered[379,0] 0.00 0.07 0.00 -0.12 0.12 1999.26 1.00\n",
"drift_decentered[380,0] -0.00 0.08 0.00 -0.12 0.13 1709.70 1.00\n",
"drift_decentered[381,0] 0.00 0.08 0.00 -0.13 0.13 2554.62 1.00\n",
"drift_decentered[382,0] 0.00 0.07 0.00 -0.12 0.12 1874.54 1.00\n",
"drift_decentered[383,0] -0.00 0.07 -0.00 -0.11 0.12 1647.29 1.00\n",
"drift_decentered[384,0] -0.00 0.08 -0.00 -0.14 0.12 2162.98 1.00\n",
"drift_decentered[385,0] 0.00 0.08 0.00 -0.11 0.14 2066.03 1.00\n",
"drift_decentered[386,0] -0.00 0.08 0.00 -0.12 0.13 1966.27 1.00\n",
"drift_decentered[387,0] 0.00 0.08 0.00 -0.12 0.13 1742.30 1.00\n",
"drift_decentered[388,0] 0.00 0.08 0.00 -0.11 0.13 1900.77 1.00\n",
"drift_decentered[389,0] -0.00 0.08 -0.01 -0.14 0.11 3195.64 1.00\n",
"drift_decentered[390,0] -0.00 0.08 -0.00 -0.13 0.12 1834.62 1.00\n",
"drift_decentered[391,0] -0.00 0.08 -0.01 -0.12 0.12 1786.30 1.00\n",
"drift_decentered[392,0] -0.01 0.08 -0.01 -0.12 0.12 2303.43 1.00\n",
"drift_decentered[393,0] -0.01 0.08 -0.01 -0.15 0.11 2214.20 1.00\n",
"drift_decentered[394,0] -0.01 0.08 -0.01 -0.13 0.11 2009.54 1.00\n",
"drift_decentered[395,0] -0.01 0.08 -0.01 -0.14 0.12 2606.63 1.00\n",
"drift_decentered[396,0] -0.01 0.07 -0.01 -0.14 0.10 2020.66 1.00\n",
"drift_decentered[397,0] -0.02 0.08 -0.02 -0.13 0.13 1817.33 1.00\n",
"drift_decentered[398,0] -0.01 0.08 -0.01 -0.14 0.12 2027.84 1.00\n",
"drift_decentered[399,0] -0.01 0.08 -0.02 -0.15 0.11 2276.22 1.00\n",
"drift_decentered[400,0] -0.01 0.08 -0.01 -0.15 0.11 1823.10 1.00\n",
"drift_decentered[401,0] -0.01 0.08 -0.01 -0.14 0.11 2340.22 1.00\n",
"drift_decentered[402,0] -0.01 0.08 -0.01 -0.14 0.11 1716.01 1.00\n",
"drift_decentered[403,0] -0.01 0.08 -0.00 -0.15 0.11 2211.88 1.00\n",
"drift_decentered[404,0] -0.01 0.08 -0.02 -0.13 0.12 2606.33 1.00\n",
"drift_decentered[405,0] -0.01 0.08 -0.01 -0.13 0.13 2485.40 1.00\n",
"drift_decentered[406,0] -0.01 0.08 -0.01 -0.14 0.10 2287.87 1.00\n",
"drift_decentered[407,0] -0.01 0.08 -0.01 -0.14 0.13 2492.25 1.00\n",
"drift_decentered[408,0] -0.03 0.07 -0.03 -0.14 0.10 1779.69 1.00\n",
"drift_decentered[409,0] -0.02 0.07 -0.02 -0.15 0.09 2111.82 1.00\n",
"drift_decentered[410,0] -0.03 0.08 -0.03 -0.15 0.10 3095.77 1.00\n",
"drift_decentered[411,0] -0.02 0.08 -0.02 -0.14 0.11 1863.08 1.00\n",
"drift_decentered[412,0] -0.00 0.07 -0.00 -0.12 0.11 1627.66 1.00\n",
"drift_decentered[413,0] -0.01 0.08 -0.01 -0.13 0.14 2377.45 1.00\n",
"drift_decentered[414,0] -0.01 0.08 -0.01 -0.15 0.11 2548.59 1.00\n",
"drift_decentered[415,0] -0.01 0.08 -0.01 -0.14 0.12 2221.67 1.00\n",
"drift_decentered[416,0] -0.01 0.08 -0.01 -0.14 0.12 2512.91 1.00\n",
" noise_scale[0] 0.08 0.00 0.08 0.07 0.08 1619.35 1.00\n",
"\n",
"Number of divergences: 0\n",
"CPU times: user 14min 41s, sys: 2.97 s, total: 14min 44s\n",
"Wall time: 14min 45s\n"
]
}
],
"source": [
"%%time\n",
"pyro.set_rng_seed(1)\n",
"pyro.clear_param_store()\n",
"time = torch.arange(float(T2)) / 365\n",
"covariates = periodic_features(T2, 365.25 / 7)\n",
"forecaster = HMCForecaster(Model2(), data[:T1], covariates[:T1])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 648x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"samples = forecaster(data[:T1], covariates, num_samples=1000)\n",
"p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)\n",
"crps = eval_crps(samples, data[T1:])\n",
"\n",
"plt.figure(figsize=(9, 3))\n",
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n",
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n",
"plt.plot(data, 'k-', label='truth')\n",
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n",
"plt.ylabel(\"log(# rides)\")\n",
"plt.xlabel(\"Week after 2011-01-01\")\n",
"plt.xlim(0, None)\n",
"plt.legend(loc=\"best\");"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 648x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(9, 3))\n",
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n",
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n",
"plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')\n",
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n",
"plt.ylabel(\"log(# rides)\")\n",
"plt.xlabel(\"Week after 2011-01-01\")\n",
"plt.xlim(T1, None)\n",
"plt.legend(loc=\"best\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Error"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Warmup: 0%| | 0/2000 [00:00, ?it/s]"
]
},
{
"ename": "RuntimeError",
"evalue": "Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient\nTensor:\n 0\n[ torch.FloatTensor{1} ]\n Trace Shapes: \n Param Sites: \n Sample Sites: \n time dist | \n value 417 | \n bias dist | tensor(1)\n value | 1\n weight dist | tensor(52)\n value | 52\ndrift_scale dist | 1\n value | 1",
"output_type": "error",
"traceback": [
"----------------------------------------------------------",
"RuntimeError Traceback (most recent call last)",
"~/pyro/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)\n 164 try:\n--> 165 ret = self.fn(*args, **kwargs)\n 166 except (ValueError, RuntimeError):\n",
"~/pyro/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)\n 10 with context:\n---> 11 return fn(*args, **kwargs)\n 12 \n",
"~/pyro/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)\n 10 with context:\n---> 11 return fn(*args, **kwargs)\n 12 \n",
"~/pyro/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)\n 10 with context:\n---> 11 return fn(*args, **kwargs)\n 12 \n",
"~/pyro/pyro/nn/module.py in __call__(self, *args, **kwargs)\n 287 with self._pyro_context:\n--> 288 return super().__call__(*args, **kwargs)\n 289 \n",
"~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)\n 529 if torch._C._get_tracing_state():\n--> 530 result = self._slow_forward(*input, **kwargs)\n 531 else:\n",
"~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)\n 515 try:\n--> 516 result = self.forward(*input, **kwargs)\n 517 finally:\n",
"~/pyro/pyro/contrib/forecast/forecaster.py in forward(self, data, covariates)\n 164 \n--> 165 self.model(zero_data, covariates)\n 166 \n",
"<ipython-input-12-971d7db4b79b> in model(self, zero_data, covariates)\n 15 with poutine.reparam(config={\"drift\": LocScaleReparam()}):\n---> 16 drift = pyro.sample(\"drift\", dist.Normal(zero_data, drift_scale).to_event(1))\n 17 \n",
"\u001b[0;32m~/pyro/pyro/primitives.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(name, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;31m# apply the stack and return its return value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m \u001b[0mapply_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m \u001b[0mframe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_process_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/reparam_messenger.py\u001b[0m in \u001b[0;36m_pyro_sample\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mnew_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreparam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"fn\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/reparam/loc_scale.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, name, fn, obs)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_full\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevent_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m constraint=constraints.unit_interval)\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"loc\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcentered\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/primitives.py\u001b[0m in \u001b[0;36mparam\u001b[0;34m(name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"name\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_param\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36m_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;31m# apply the stack and return its return value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m \u001b[0mapply_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 263\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0mdefault_process_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36mdefault_process_message\u001b[0;34m(msg)\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 159\u001b[0;31m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"fn\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"args\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"kwargs\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/params/param_store.py\u001b[0m in \u001b[0;36mget_param\u001b[0;34m(self, name, init_tensor, constraint, event_dim)\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetdefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconstraint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/params/param_store.py\u001b[0m in \u001b[0;36msetdefault\u001b[0;34m(self, name, init_constrained_value, constraint)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;31m# get the param, which is guaranteed to exist\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/params/param_store.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0mconstraint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_constraints\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mconstrained_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconstraint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munconstrained_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0mconstrained_value\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munconstrained\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mweakref\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mref\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munconstrained_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/distributions/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache_size\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mx_old\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_old\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_x_y\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/distributions/transforms.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 364\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 365\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_clipped_sigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 366\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/distributions/transforms.py\u001b[0m in \u001b[0;36m_clipped_sigmoid\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[0mfinfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 349\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclamp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtiny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mfinfo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 350\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient\nTensor:\n 0\n[ torch.FloatTensor{1} ]", | |
"\nDuring handling of the above exception, another exception occurred:\n", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<timed exec>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/contrib/forecast/forecaster.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, model, data, covariates, num_warmup, num_samples, num_chains, dense_mass, jit_compile, max_tree_depth)\u001b[0m\n\u001b[1;32m 327\u001b[0m max_tree_depth=max_tree_depth, max_plate_nesting=max_plate_nesting)\n\u001b[1;32m 328\u001b[0m \u001b[0mmcmc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMCMC\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwarmup_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_warmup\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_chains\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_chains\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 329\u001b[0;31m \u001b[0mmcmc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcovariates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0;31m# conditions to compute rhat\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnum_chains\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnum_samples\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnum_chains\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnum_samples\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/mcmc/api.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 355\u001b[0m \u001b[0mz_flat_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_chains\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidation_enabled\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable_validation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 357\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchain_id\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 358\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchain_id\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchain_id\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/mcmc/api.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,\n\u001b[1;32m 167\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_chains\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m *args, **kwargs):\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0msample\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m \u001b[0;31m# sample, chain_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcleanup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/mcmc/api.py\u001b[0m in \u001b[0;36m_gen_samples\u001b[0;34m(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_gen_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwarmup_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchain_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m \u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwarmup_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 111\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitial_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;31m# yield structure (key, value.shape) of params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/mcmc/hmc.py\u001b[0m in \u001b[0;36msetup\u001b[0;34m(self, warmup_steps, *args, **kwargs)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_initialize_model_properties\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mpotential_energy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpotential_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitial_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitial_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpotential_energy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitial_params\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/mcmc/util.py\u001b[0m in \u001b[0;36m_potential_fn_jit\u001b[0;34m(self, skip_jit_warnings, jit_options, params)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mskip_jit_warnings\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0m_pe_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mignore_jit_warnings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_pe_jit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_pe_jit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mjit_options\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/jit/__init__.py\u001b[0m in \u001b[0;36mtrace\u001b[0;34m(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)\u001b[0m\n\u001b[1;32m 904\u001b[0m traced = torch._C._create_function_from_trace(name, func, example_inputs,\n\u001b[1;32m 905\u001b[0m \u001b[0mvar_lookup_fn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 906\u001b[0;31m _force_outplace)\n\u001b[0m\u001b[1;32m 907\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 908\u001b[0m \u001b[0;31m# Check the trace against new traces created from user-specified inputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/contextlib.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_recreate_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/mcmc/util.py\u001b[0m in \u001b[0;36m_pe_jit\u001b[0;34m(*zi)\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_pe_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mzi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mzi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 279\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_potential_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 280\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mskip_jit_warnings\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/mcmc/util.py\u001b[0m in \u001b[0;36m_potential_fn\u001b[0;34m(self, params)\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0mcond_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpoutine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcondition\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams_constrained\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 260\u001b[0m model_trace = poutine.trace(cond_model).get_trace(*self.model_args,\n\u001b[0;32m--> 261\u001b[0;31m **self.model_kwargs)\n\u001b[0m\u001b[1;32m 262\u001b[0m \u001b[0mlog_joint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_prob_evaluator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_prob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_trace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 263\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36mget_trace\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mCalls\u001b[0m \u001b[0mthis\u001b[0m \u001b[0mpoutine\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mreturns\u001b[0m \u001b[0mits\u001b[0m \u001b[0mtrace\u001b[0m \u001b[0minstead\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mfunction\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0ms\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \"\"\"\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmsngr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraceback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0mshapes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmsngr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat_shapes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mu\"{}\\n{}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexc_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceback\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmsngr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_node\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"_RETURN\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"_RETURN\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"return\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mret\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mret\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/pyro/pyro/poutine/trace_messenger.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m args=args, kwargs=kwargs)\n\u001b[1;32m 164\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mValueError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraceback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_context_wrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/nn/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 287\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pyro_context\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 288\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 289\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 528\u001b[0m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_tracing_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 530\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 532\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_slow_forward\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[0mrecording_scopes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 516\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 517\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 518\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecording_scopes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/contrib/forecast/forecaster.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, data, covariates)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forecast\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzero_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcovariates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forecast\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\".predict() was not called by .model()\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-12-971d7db4b79b>\u001b[0m in \u001b[0;36mmodel\u001b[0;34m(self, zero_data, covariates)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;31m# correct if you removed this context manager, but the fit appears to be worse.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpoutine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreparam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"drift\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mLocScaleReparam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mdrift\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"drift\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzero_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdrift_scale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;31m# After we sample the iid \"drift\" noise we can combine it in any time-dependent way.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/primitives.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(name, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"is_observed\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;31m# apply the stack and return its return value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m \u001b[0mapply_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 114\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0mpointer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpointer\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m \u001b[0mframe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 194\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"stop\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/messenger.py\u001b[0m in \u001b[0;36m_process_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0mmethod_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"_pyro_{}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"type\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/reparam_messenger.py\u001b[0m in \u001b[0;36m_pyro_sample\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0mreparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args_kwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mnew_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreparam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"name\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"fn\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mreparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/infer/reparam/loc_scale.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, name, fn, obs)\u001b[0m\n\u001b[1;32m 59\u001b[0m centered = pyro.param(\"{}_centered\",\n\u001b[1;32m 60\u001b[0m \u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_full\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mevent_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m constraint=constraints.unit_interval)\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"loc\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcentered\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"scale\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscale\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0mcentered\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/primitives.py\u001b[0m in \u001b[0;36mparam\u001b[0;34m(name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \"\"\"\n\u001b[1;32m 60\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"name\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_param\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 62\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36m_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 260\u001b[0m }\n\u001b[1;32m 261\u001b[0m \u001b[0;31m# apply the stack and return its return value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 262\u001b[0;31m \u001b[0mapply_stack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 263\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[0m_fn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_effectful\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 198\u001b[0;31m \u001b[0mdefault_process_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mframe\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mpointer\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/poutine/runtime.py\u001b[0m in \u001b[0;36mdefault_process_message\u001b[0;34m(msg)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 159\u001b[0;31m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"value\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"fn\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"args\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"kwargs\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;31m# after fn has been called, update msg to prevent it from being called again.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/params/param_store.py\u001b[0m in \u001b[0;36mget_param\u001b[0;34m(self, name, init_tensor, constraint, event_dim)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetdefault\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minit_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconstraint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/params/param_store.py\u001b[0m in \u001b[0;36msetdefault\u001b[0;34m(self, name, init_constrained_value, constraint)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;31m# get the param, which is guaranteed to exist\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;31m# -------------------------------------------------------------------------------\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/pyro/pyro/params/param_store.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0;31m# compute the constrained value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0mconstraint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_constraints\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mconstrained_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconstraint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munconstrained_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0mconstrained_value\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munconstrained\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mweakref\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mref\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munconstrained_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/distributions/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 123\u001b[0m \"\"\"\n\u001b[1;32m 124\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache_size\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 126\u001b[0m \u001b[0mx_old\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_old\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_x_y\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mx_old\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/distributions/transforms.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 363\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 364\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 365\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_clipped_sigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 366\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 367\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_inverse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/distributions/transforms.py\u001b[0m in \u001b[0;36m_clipped_sigmoid\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 347\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_clipped_sigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[0mfinfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 349\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclamp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfinfo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtiny\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mfinfo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 350\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient\nTensor:\n 0\n[ torch.FloatTensor{1} ]\n Trace Shapes: \n Param Sites: \n Sample Sites: \n time dist | \n value 417 | \n bias dist | tensor(1)\n value | 1\n weight dist | tensor(52)\n value | 52\ndrift_scale dist | 1\n value | 1" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"pyro.set_rng_seed(1)\n", | |
"pyro.clear_param_store()\n", | |
"time = torch.arange(float(T2)) / 365\n", | |
"covariates = periodic_features(T2, 365.25 / 7)\n", | |
"forecaster = HMCForecaster(Model2(), data[:T1], covariates[:T1], jit_compile=True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Heavy-tailed noise\n", | |
"\n", | |
"Our final univariate model generalizes from Gaussian noise to heavy-tailed [Stable](http://docs.pyro.ai/en/latest/distributions.html#stable) noise. The only difference is the `noise_dist`, which now takes two new parameters: `stability` determines the tail weight (it lies in (0, 2], with smaller values giving heavier tails and 2 recovering the Normal), and `skew` determines the relative size of positive versus negative spikes.\n", | |
"\n", | |
"The [Stable distribution](https://en.wikipedia.org/wiki/Stable_distribution) is a natural heavy-tailed generalization of the Normal distribution, but it is difficult to work with because its density function has no closed form in general. Pyro implements auxiliary variable methods for working with Stable distributions. To tell Pyro to use those methods, we wrap the final line of the model in a [poutine.reparam()](http://docs.pyro.ai/en/latest/poutine.html#pyro.poutine.handlers.reparam) effect handler that applies the [StableReparam](http://docs.pyro.ai/en/latest/infer.reparam.html#pyro.infer.reparam.stable.StableReparam) transform to the implicit observe site named \"residual\". You can use Stable distributions for other sites by specifying `config={\"my_site_name\": StableReparam()}`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Model3(ForecastingModel):\n", | |
" def model(self, zero_data, covariates):\n", | |
" data_dim = zero_data.size(-1)\n", | |
" feature_dim = covariates.size(-1)\n", | |
" bias = pyro.sample(\"bias\", dist.Normal(0, 10).expand([data_dim]).to_event(1))\n", | |
" weight = pyro.sample(\"weight\", dist.Normal(0, 0.1).expand([feature_dim]).to_event(1))\n", | |
"\n", | |
" drift_scale = pyro.sample(\"drift_scale\", dist.LogNormal(-20, 5).expand([1]).to_event(1))\n", | |
" with self.time_plate:\n", | |
" with poutine.reparam(config={\"drift\": LocScaleReparam()}):\n", | |
" drift = pyro.sample(\"drift\", dist.Normal(zero_data, drift_scale).to_event(1))\n", | |
" motion = drift.cumsum(-2) # A Brownian motion.\n", | |
"\n", | |
" prediction = motion + bias + (weight * covariates).sum(-1, keepdim=True)\n", | |
" assert prediction.shape[-2:] == zero_data.shape\n", | |
"\n", | |
" # The next part of the model creates a likelihood or noise distribution.\n", | |
" # Again we'll be Bayesian and write this as a probabilistic program with\n", | |
" # priors over parameters, and again we'll use zero_data as a noise template.\n", | |
" stability = pyro.sample(\"noise_stability\", dist.Uniform(1, 2).expand([1]).to_event(1))\n", | |
" skew = pyro.sample(\"noise_skew\", dist.Uniform(-1, 1).expand([1]).to_event(1))\n", | |
" scale = pyro.sample(\"noise_scale\", dist.LogNormal(-5, 5).expand([1]).to_event(1))\n", | |
" noise_dist = dist.Stable(stability, skew, scale, zero_data)\n", | |
"\n", | |
" # We need to use a reparameterizer to handle the Stable distribution.\n", | |
" # Note \"residual\" is the name of Pyro's internal sample site in self.predict().\n", | |
" with poutine.reparam(config={\"residual\": StableReparam()}):\n", | |
"            self.predict(noise_dist, prediction)" | |
] | |
}, | |
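{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for what the `stability` and `skew` parameters of `Model3` control, here is a minimal standalone sketch that does not use Pyro at all: it draws samples with the Chambers-Mallows-Stuck method using only Python's standard library. The helpers `sample_stable` and `tail_fraction` are hypothetical names introduced for illustration. The sketch assumes `stability != 1` and uses the \"S1\" parameterization, which differs by a location shift from the S0 parameterization documented for Pyro's `Stable` when `skew != 0`.\n",
"\n",
"With `stability=2` the samples are Gaussian (variance 2 at unit scale), so essentially none exceed 10 in absolute value. With `stability=1.5` a small but clearly nonzero fraction do, and a positive `skew` puts most of those spikes on the positive side."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import random\n",
"\n",
"\n",
"def sample_stable(stability, skew, num_samples, seed=0):\n",
"    # Hypothetical helper: Chambers-Mallows-Stuck sampler for a standard\n",
"    # (scale=1, loc=0) stable law in the S1 parameterization. Illustration only,\n",
"    # not Pyro's implementation. Valid for 0 < stability <= 2, stability != 1.\n",
"    if stability == 1:\n",
"        raise ValueError(\"stability == 1 needs a different formula\")\n",
"    rng = random.Random(seed)\n",
"    t = skew * math.tan(math.pi * stability / 2)\n",
"    b = math.atan(t) / stability\n",
"    s = (1 + t ** 2) ** (1 / (2 * stability))\n",
"    samples = []\n",
"    for _ in range(num_samples):\n",
"        v = rng.uniform(-math.pi / 2, math.pi / 2)  # uniform angle\n",
"        w = rng.expovariate(1.0)  # unit-rate exponential\n",
"        x = (s * math.sin(stability * (v + b)) / math.cos(v) ** (1 / stability)\n",
"             * (math.cos(v - stability * (v + b)) / w) ** ((1 - stability) / stability))\n",
"        samples.append(x)\n",
"    return samples\n",
"\n",
"\n",
"def tail_fraction(samples, threshold):\n",
"    # Hypothetical helper: fractions of samples above +threshold and below -threshold.\n",
"    n = len(samples)\n",
"    right = sum(x > threshold for x in samples) / n\n",
"    left = sum(x < -threshold for x in samples) / n\n",
"    return right, left\n",
"\n",
"\n",
"for stability, skew in [(2.0, 0.0), (1.5, 0.0), (1.5, 0.9)]:\n",
"    right, left = tail_fraction(sample_stable(stability, skew, 20000), 10.0)\n",
"    print(\"stability={}, skew={}: P(X > 10) ~ {:.4f}, P(X < -10) ~ {:.4f}\".format(\n",
"        stability, skew, right, left))"
]
},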
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"pyro.set_rng_seed(2)\n",
"pyro.clear_param_store()\n",
"time = torch.arange(float(T2)) / 365\n",
"covariates = periodic_features(T2, 365.25 / 7)\n",
"forecaster = HMCForecaster(Model3(), data[:T1], covariates[:T1])\n",
"# HMCForecaster has no .guide (that belongs to the SVI-based Forecaster).\n",
"# Assumption: posterior samples live in the private ._samples dict, each tensor\n",
"# shaped (num_samples, *site_shape); we print the median of each scalar site.\n",
"for name, value in forecaster._samples.items():\n",
"    if value[0].numel() == 1:\n",
"        print(\"{} = {:0.4g}\".format(name, value.median().item()))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"samples = forecaster(data[:T1], covariates, num_samples=1000)\n",
"p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)\n",
"crps = eval_crps(samples, data[T1:])\n",
"\n",
"plt.figure(figsize=(9, 3))\n",
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n",
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n",
"plt.plot(data, 'k-', label='truth')\n",
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n",
"plt.ylabel(\"log(# rides)\")\n",
"plt.xlabel(\"Week after 2011-01-01\")\n",
"plt.xlim(0, None)\n",
"plt.legend(loc=\"best\");"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 3))\n",
"plt.fill_between(torch.arange(T1, T2), p10, p90, color=\"red\", alpha=0.3)\n",
"plt.plot(torch.arange(T1, T2), p50, 'r-', label='forecast')\n",
"plt.plot(torch.arange(T1, T2), data[T1:], 'k-', label='truth')\n",
"plt.title(\"Total weekly ridership (CRPS = {:0.3g})\".format(crps))\n",
"plt.ylabel(\"log(# rides)\")\n",
"plt.xlabel(\"Week after 2011-01-01\")\n",
"plt.xlim(T1, None)\n",
"plt.legend(loc=\"best\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Backtesting\n",
"\n",
"To compare our Gaussian `Model2` and Stable `Model3` we'll use a simple [backtest()](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.evaluate.backtest) helper. By default this helper evaluates three metrics: [CRPS](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.evaluate.eval_crps) assesses the distributional accuracy of heavy-tailed data, [MAE](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.evaluate.eval_mae) assesses the point accuracy of heavy-tailed data, and [RMSE](http://docs.pyro.ai/en/latest/contrib.forecast.html#pyro.contrib.forecast.evaluate.eval_rmse) assesses the accuracy of Normal-tailed data. Since both models are fit with `HMCForecaster`, we pass no `forecaster_options`: the SVI settings `learning_rate`, `log_every` and `warm_start` belong to the SVI-based `Forecaster`, and `HMCForecaster` rejects them with a `TypeError`."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"pyro.set_rng_seed(1)\n",
"pyro.clear_param_store()\n",
"windows2 = backtest(data, covariates, Model2, forecaster_fn=HMCForecaster,\n",
"                    min_train_window=104, test_window=52, stride=26)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"pyro.set_rng_seed(1)\n",
"pyro.clear_param_store()\n",
"windows3 = backtest(data, covariates, Model3, forecaster_fn=HMCForecaster,\n",
"                    min_train_window=104, test_window=52, stride=26)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(3, figsize=(8, 6), sharex=True)\n",
"axes[0].set_title(\"Gaussian versus Stable accuracy over {} windows\".format(len(windows2)))\n",
"axes[0].plot([w[\"crps\"] for w in windows2], \"b<\", label=\"Gaussian\")\n",
"axes[0].plot([w[\"crps\"] for w in windows3], \"r>\", label=\"Stable\")\n",
"axes[0].set_ylabel(\"CRPS\")\n",
"axes[1].plot([w[\"mae\"] for w in windows2], \"b<\", label=\"Gaussian\")\n",
"axes[1].plot([w[\"mae\"] for w in windows3], \"r>\", label=\"Stable\")\n",
"axes[1].set_ylabel(\"MAE\")\n",
"axes[2].plot([w[\"rmse\"] for w in windows2], \"b<\", label=\"Gaussian\")\n",
"axes[2].plot([w[\"rmse\"] for w in windows3], \"r>\", label=\"Stable\")\n",
"axes[2].set_ylabel(\"RMSE\")\n",
"axes[0].legend(loc=\"best\")\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that RMSE is a poor metric for evaluating heavy-tailed data. Our Stable model has tails so heavy that its variance is infinite (its stability parameter lies below 2), so we cannot expect RMSE to converge; this explains the occasional outlying points."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}