Skip to content

Instantly share code, notes, and snippets.

@WillianFuks
Last active December 17, 2020 20:21
Show Gist options
  • Select an option

  • Save WillianFuks/78495858bb85d57367a22be4ca08f7bd to your computer and use it in GitHub Desktop.

Select an option

Save WillianFuks/78495858bb85d57367a22be4ca08f7bd to your computer and use it in GitHub Desktop.
from typing import Dict
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import numpy as np
def plot_components(index,
                    one_step_dists: Dict[object, tfp.distributions.Distribution],
                    forecast_dists: Dict[object, tfp.distributions.Distribution],
                    mu_sig=None):
    """
    Plot the mean and an approximate 95% band of each component's contribution.

    For every component, the one-step (pre-intervention) and forecast
    (post-intervention) distributions are joined along the last axis. They are
    drawn in one subplot per component, rescaled with ``mu_sig`` when it is given.
    The figure is displayed with ``plt.show()``. Nothing is plotted when
    ``one_step_dists`` is empty.

    Args
    ----
    index: pandas index
        x-axis values. It must have one entry per pre-intervention step followed
        by one entry per post-intervention step, in the same order as the
        concatenated means.
    one_step_dists: Dict[component, tfp.distributions.Distribution]
        One-step predictive distributions obtained by decomposing the components
        on the training data (e.g. ``tfp.sts.decompose_by_component``). Keys are
        the model components. A key's ``name`` attribute is used as the subplot
        title, falling back to ``str(key)`` (so plain string keys also work).
    forecast_dists: Dict[component, tfp.distributions.Distribution]
        Forecast distributions obtained by decomposing the posterior forecast.
        It must contain every key of ``one_step_dists``.
    mu_sig: Optional[Tuple[float, float]]
        ``(mu, sig)`` used to map normalized values back to the original scale as
        ``value * sig + mu``. ``None`` means no rescaling.
    """
    line_color, band_color = 'blue', 'orangered'
    num_components = len(one_step_dists)
    if num_components == 0:
        # Nothing to draw; avoid creating a zero-height figure.
        return
    # `is not None` rather than truthiness: an array-like (mu, sig) would raise
    # on `bool()`.
    mu, sig = mu_sig if mu_sig is not None else (0, 1)
    fig = plt.figure(figsize=(12, 2.5 * num_components))
    for i, (component, pre_dist) in enumerate(one_step_dists.items()):
        # Keys from tfp.sts.decompose_* are component objects exposing `.name`;
        # fall back to str() for any other key type.
        name = getattr(component, 'name', str(component))
        post_dist = forecast_dists[component]
        # Pre values come first, then post values, so `index` must follow the
        # same order.
        means = np.concatenate([pre_dist.mean(), post_dist.mean()], axis=-1)
        stddevs = np.concatenate([pre_dist.stddev(), post_dist.stddev()], axis=-1)
        # 1.96 stddevs gives ~95% coverage if the marginal is roughly Gaussian.
        lower = (means - 1.96 * stddevs) * sig + mu
        upper = (means + 1.96 * stddevs) * sig + mu
        ax = fig.add_subplot(num_components, 1, 1 + i)
        ax.plot(index, means * sig + mu, lw=2, c=line_color)
        ax.fill_between(index, lower, upper, color=band_color, alpha=0.3)
        ax.set_title(name)
    fig.tight_layout()
    plt.show()
# NOTE(review): `ci` is not defined in this snippet; it presumably is a fitted
# CausalImpact object (e.g. from tfcausalimpact) created earlier. Confirm.
# Use the normalized pre-period data when it exists, otherwise the raw pre-period
# data. Only the first column (assumed to be the response) is used, cast to float32.
pre_data = (
    ci.normed_pre_data.iloc[:, 0].astype(np.float32) if ci.normed_pre_data is not None else
    ci.pre_data.iloc[:, 0].astype(np.float32)
)
# Per-component distributions over the pre-period (one-step) and the forecast period.
one_step_dists = tfp.sts.decompose_by_component(ci.model, pre_data, ci.model_samples)
forecast_dists = tfp.sts.decompose_forecast_by_component(ci.model, ci.posterior_dist, ci.model_samples)
# `append` keeps the pre labels followed by the post labels, in order and without
# de-duplication. This matches the positional pre-then-post concatenation of the
# component arrays. `union` would sort and de-duplicate, which can misalign labels
# and values.
plot_components(ci.pre_data.index.append(ci.post_data.index), one_step_dists, forecast_dists, ci.mu_sig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment