Last active
December 17, 2020 20:21
-
-
Save WillianFuks/78495858bb85d57367a22be4ca08f7bd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Dict | |
| import tensorflow_probability as tfp | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def plot_components(
    index,
    one_step_dists: Dict[tfp.sts.StructuralTimeSeries, tfp.distributions.Distribution],
    forecast_dists: Dict[tfp.sts.StructuralTimeSeries, tfp.distributions.Distribution],
    mu_sig=None,
):
    """
    Plot the contribution of each structural component of the input model.

    For every component, the pre-intervention (one-step) and post-intervention
    (forecast) distributions are concatenated along the last axis and drawn in
    their own subplot. Each subplot shows a mean line and a band of
    +/- 1.96 standard deviations, and is titled with the component's ``name``.

    Args
    ----
      index: pandas index
        Should be the index from pre-intervention data merged to the
        post-intervention index. It is used as the x axis, so its length must
        match the concatenated pre + post series of each component.
      one_step_dists: Dict[tfp.sts.StructuralTimeSeries, tfp.distributions.Distribution]
        One step predictive distribution obtained from decomposition of
        components for training data (e.g. ``tfp.sts.decompose_by_component``).
        Keys are component objects; their ``.name`` is used as subplot title.
      forecast_dists: Dict[tfp.sts.StructuralTimeSeries, tfp.distributions.Distribution]
        Forecasts distribution obtained from decomposition of forecasts
        distributions using the posterior. Must contain every key present in
        ``one_step_dists``.
      mu_sig: tuple (mu, sig) or None
        When given, plotted values are mapped back from the normalized scale
        as ``value * sig + mu``. When None, values are plotted unchanged
        (equivalent to mu=0, sig=1).

    Returns
    -------
      matplotlib.figure.Figure
        The figure that was drawn (it is also displayed with ``plt.show()``).
    """
    mean_color, band_color = 'blue', 'orangered'
    num_components = len(one_step_dists)
    fig = plt.figure(figsize=(12, 2.5 * num_components))
    # Compare against None explicitly: truthiness of a NumPy array raises.
    mu, sig = mu_sig if mu_sig is not None else (0, 1)
    for i, (component, pre_dist) in enumerate(one_step_dists.items()):
        post_dist = forecast_dists[component]
        means = np.concatenate([pre_dist.mean(), post_dist.mean()], axis=-1)
        stddevs = np.concatenate([pre_dist.stddev(), post_dist.stddev()], axis=-1)
        # NOTE(review): mu is added to every component's contribution rather
        # than once to their sum; confirm this is the intended presentation.
        lower = (means - 1.96 * stddevs) * sig + mu
        upper = (means + 1.96 * stddevs) * sig + mu
        ax = fig.add_subplot(num_components, 1, 1 + i)
        ax.plot(index, means * sig + mu, lw=2, c=mean_color)
        ax.fill_between(index, lower, upper, color=band_color, alpha=0.3)
        ax.set_title(component.name)
    fig.tight_layout()
    plt.show()
    return fig
# Response series the model was fitted on: use the normalized frame when the
# CausalImpact object produced one, otherwise fall back to the raw frame.
if ci.normed_pre_data is not None:
    response_frame = ci.normed_pre_data
else:
    response_frame = ci.pre_data
pre_data = response_frame.iloc[:, 0].astype(np.float32)

# Decompose the fitted model into per-component distributions, both over the
# training period and over the forecast period.
one_step_dists = tfp.sts.decompose_by_component(
    ci.model, pre_data, ci.model_samples
)
forecast_dists = tfp.sts.decompose_forecast_by_component(
    ci.model, ci.posterior_dist, ci.model_samples
)

# The x axis spans the pre- and post-intervention periods together.
full_index = ci.pre_data.index.union(ci.post_data.index)
plot_components(full_index, one_step_dists, forecast_dists, ci.mu_sig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment