Last active
February 26, 2018 00:29
-
-
Save dslaw/edae9837ade4733ca19861817c66bf6b to your computer and use it in GitHub Desktop.
Plot state sequence with generative distributions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy.stats import norm | |
import numpy as np | |
import matplotlib.pyplot as plt | |
# Apply matplotlib's built-in "ggplot" style sheet to every figure below.
plt.style.use(style="ggplot")
def save_and_close(fig, filename, dpi=300):
    """Save ``fig`` to ``filename`` and then close it via pyplot.

    The figure is closed even if saving raises, so a failed save does not
    leave the figure open.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure to write out.
    filename : str or path-like
        Destination, passed straight through to ``fig.savefig``.
    dpi : float, optional
        Output resolution passed to ``fig.savefig``; defaults to 300.
    """
    try:
        fig.savefig(filename, dpi=dpi)
    finally:
        plt.close(fig)
def remove_ticks(ax, xaxis=False, yaxis=False):
    """Hide the tick labels and major tick marks on selected axes of ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose ticks should be hidden.
    xaxis, yaxis : bool
        Which axis (or both) to strip.

    Returns
    -------
    The same ``ax``, so calls can be chained.
    """
    # Pair each requested axis with the accessor for its tick labels.
    targets = []
    if xaxis:
        targets.append((ax.get_xticklabels, ax.xaxis))
    if yaxis:
        targets.append((ax.get_yticklabels, ax.yaxis))

    for get_labels, axis in targets:
        plt.setp(get_labels(), visible=False)
        plt.setp(axis.get_majorticklines(), visible=False)
    return ax
def share_axes(*axes, xaxis=False, yaxis=False):
    """Link the x- and/or y-axis of several Axes together.

    The first Axes acts as the leader; every other Axes is attached to it
    with ``Axes.sharex`` / ``Axes.sharey``. Those methods replace the old
    ``get_shared_x_axes().join(...)`` idiom, which no longer works on recent
    matplotlib. Per the matplotlib docs, ``sharex``/``sharey`` cannot be used
    on an Axes whose axis is already shared with a different Axes.

    Parameters
    ----------
    *axes : matplotlib Axes
        Axes to link; nothing happens when none are given.
    xaxis, yaxis : bool
        Which axis (or both) to share.
    """
    if not axes:
        return
    leader, *followers = axes
    for other in followers:
        if xaxis:
            other.sharex(leader)
        if yaxis:
            other.sharey(leader)
# Simulate from a three-state Hidden Markov Model with Gaussian emissions.
n_components = 3
n_draws = 200
initial_state = 0
# Row i holds the probabilities of moving from state i to each state; every
# row sums to one, as required by ``rs.choice(..., p=...)`` below.
transmat = np.array([
    [.75, .25, 0],
    [.25, .5, .25],
    [.05, .2, .75],
])
# Per-state Gaussian emission parameters (mean and standard deviation).
means = np.array([25, 35, 60])
stds = np.array([1.41, 2.23, 3.9])

rs = np.random.RandomState(13)
# Use the builtin int/float dtypes: the np.int / np.float aliases were
# removed in NumPy 1.24 and raise AttributeError there.
labels = np.full(n_draws, initial_state, dtype=int)
data = np.empty(n_draws, dtype=float)

# Walk the Markov chain: each next state is drawn from the current row.
for t in range(n_draws - 1):
    i = labels[t]
    labels[t + 1] = rs.choice(n_components, p=transmat[i, :])

# Emit one observation per time step from that step's state distribution.
for t in range(n_draws):
    k = labels[t]
    data[t] = rs.normal(means[k], stds[k])
# Shared plotting settings used by every figure below.
n_points = 250                      # samples along each density curve
line_color = (.5, .5, .5, .25)      # translucent grey RGBA for the joining line
line_kwds = dict(alpha=.5, linewidth=1.2)
scatter_kwds = dict(s=5)
time_index = np.arange(data.shape[0])  # x positions 0..n_draws-1
# Figure 1: the time-series with each observation coloured by its state,
# next to the density of every state's emission distribution.

# Layout: a wide time-series panel and a narrow density panel sharing its
# y-axis, so the densities are drawn on the same data scale.
fig = plt.figure()
gs = plt.GridSpec(1, 2, width_ratios=[4, 1])
ax_ts = fig.add_subplot(gs[0, 0])
ax_density = fig.add_subplot(gs[0, 1], sharey=ax_ts)

# One faint line joins all observations in time order. The markers are drawn
# with one scatter call per state: giving each state its own collection makes
# matplotlib cycle through the global palette.
ax_ts.plot(time_index, data, c=line_color)
observed_states = np.unique(labels)
for state in observed_states:
    in_state = labels == state
    ax_ts.scatter(time_index[in_state], data[in_state], **scatter_kwds)
ax_ts.set_xlabel("Time")

# Densities are drawn sideways (density on x, data value on y) over the
# time-series' y-range. They use the same state order as the scatter calls,
# so the fresh density axis walks the palette in step.
ymin, ymax = ax_ts.get_ybound()
ypts = np.linspace(ymin, ymax, n_points)
for state in observed_states:
    ax_density.plot(norm.pdf(ypts, loc=means[state], scale=stds[state]), ypts)

# The density panel only needs its curves, so hide its ticks on both axes.
remove_ticks(ax_density, xaxis=True, yaxis=True)
save_and_close(fig, "state-sequence-1.png")
# It may be of interest to show each state as its own signal.

# Layout: a wide time-series panel and a narrow density panel sharing its
# y-axis.
fig = plt.figure()
gs = plt.GridSpec(1, 2, width_ratios=[4, 1])
ax_ts = fig.add_subplot(gs[0, 0])
ax_density = fig.add_subplot(gs[0, 1], sharey=ax_ts)

# Plot time-series: each state gets its own line and markers. A state's
# observations are therefore joined to one another, spanning any gaps where
# other states were active, rather than in overall time order.
for k in np.unique(labels):
    mask = labels == k
    ax_ts.plot(time_index[mask], data[mask], **line_kwds)
    ax_ts.scatter(time_index[mask], data[mask], **scatter_kwds)
ax_ts.set_xlabel("Time")

# Plot densities with data on the y-axis, evaluated once over the
# time-series' y-range and drawn in the same state order as above.
ymin, ymax = ax_ts.get_ybound()
ypts = np.linspace(ymin, ymax, n_points)
for k in np.unique(labels):
    densities = norm.pdf(ypts, loc=means[k], scale=stds[k])
    ax_density.plot(densities, ypts)

# Remove ticks from the density subplot.
remove_ticks(ax_density, xaxis=True, yaxis=True)
save_and_close(fig, "state-sequence-2.png")
# Take it to the conclusion, and place each state on its own subplot.
cmap = plt.get_cmap("Dark2")  # Not sure what ggplot uses...

# Order states by descending mean, so the top row holds the highest-mean
# state and the rows read top-down/bottom-up like the values themselves.
states = np.argsort(means)[::-1]
K = len(states)

# One row per state for its time-series, plus a right-hand column that spans
# every row and holds all the distributions.
fig = plt.figure()
gs = plt.GridSpec(K, 2, width_ratios=[4, 1])
axes_ts = [fig.add_subplot(gs[row, 0]) for row in range(K)]
share_axes(*axes_ts, xaxis=True, yaxis=True)
ax_density = fig.add_subplot(gs[:, 1])

bottom_row = K - 1
for row, (k, ax) in enumerate(zip(states, axes_ts)):
    mask = labels == k
    t_k = time_index[mask]
    y_k = data[mask]
    # Every selected label equals k, so each marker gets the same colour as
    # the line (passed as one RGBA row per point).
    point_colors = cmap(labels[mask])
    ax.plot(t_k, y_k, c=cmap(k), **line_kwds)
    ax.scatter(t_k, y_k, c=point_colors, **scatter_kwds)
    # Only the bottom row of the column keeps its x tick labels.
    if row != bottom_row:
        remove_ticks(ax, xaxis=True, yaxis=False)
axes_ts[-1].set_xlabel("Time")

# Evaluate every density over the combined y-range of the time-series rows.
lows, highs = zip(*(ax.get_ybound() for ax in axes_ts))
ymin, ymax = min(lows), max(highs)
ypts = np.linspace(ymin, ymax, n_points)
for k in states:
    ax_density.plot(norm.pdf(ypts, loc=means[k], scale=stds[k]), ypts, c=cmap(k))

# The density panel spans all rows, so its y ticks no longer line up with
# the time-series. Keep its y tick labels and move them to the right side.
remove_ticks(ax_density, xaxis=True, yaxis=False)
ax_density.yaxis.tick_right()
save_and_close(fig, "state-sequence-3.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment