Skip to content

Instantly share code, notes, and snippets.

@dslaw
Last active February 26, 2018 00:29
Plot state sequence with generative distributions
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
def save_and_close(fig, filename, dpi=300):
    """Save ``fig`` to ``filename`` and then close the figure.

    Parameters
    ----------
    fig
        Figure object providing ``savefig``.
    filename
        Destination forwarded unchanged to ``fig.savefig``.
    dpi
        Resolution forwarded to ``fig.savefig``. Defaults to 300, the
        value that used to be hard-coded here.

    Returns
    -------
    None
    """
    fig.savefig(filename, dpi=dpi)
    # Close through pyplot once the file has been written.
    plt.close(fig)
def remove_ticks(ax, xaxis=False, yaxis=False):
    """Hide tick labels and major tick lines on the chosen axes of ``ax``.

    Parameters
    ----------
    ax
        Axes whose ticks should be hidden.
    xaxis, yaxis
        Select which axis (or both) to strip. With both left ``False``
        nothing is changed.

    Returns
    -------
    The same ``ax`` that was passed in.
    """
    # Collect the getters rather than their results so that each getter
    # is only called immediately before its artists are hidden.
    selected = []
    if xaxis:
        selected.append((ax.get_xticklabels, ax.xaxis.get_majorticklines))
    if yaxis:
        selected.append((ax.get_yticklabels, ax.yaxis.get_majorticklines))

    for get_labels, get_lines in selected:
        plt.setp(get_labels(), visible=False)
        plt.setp(get_lines(), visible=False)
    return ax
def share_axes(*axes, xaxis=False, yaxis=False):
    """Make every given axes share its x and/or y axis with the first one.

    Parameters
    ----------
    *axes
        Axes objects to link. The first one acts as the reference that
        all the others are shared with. Passing fewer than two axes is a
        no-op.
    xaxis, yaxis
        Select which axis (or both) to share.

    Returns
    -------
    None

    Notes
    -----
    Uses ``Axes.sharex`` / ``Axes.sharey`` instead of
    ``get_shared_x_axes().join(...)``: joining through the grouper was
    deprecated in matplotlib 3.6 and is no longer available in current
    releases. The replacement methods require matplotlib 3.3 or newer.
    """
    # Nothing to link with zero or one axes (also avoids indexing an
    # empty tuple).
    if len(axes) < 2:
        return

    leader, *followers = axes
    for follower in followers:
        if xaxis:
            follower.sharex(leader)
        if yaxis:
            follower.sharey(leader)
    return
# Simulate from a three-state Hidden Markov Model with Gaussian emissions.
n_components = 3
n_draws = 200
initial_state = 0

# Row i holds the probabilities of moving from state i to each state.
transmat = np.array([
    [.75, .25, 0],
    [.25, .5, .25],
    [.05, .2, .75],
])

# Per-state parameters of the Gaussian emission distributions.
means = np.array([25, 35, 60])
stds = np.array([1.41, 2.23, 3.9])

# Fixed seed so the simulated sequence is reproducible.
rs = np.random.RandomState(13)

# Use the builtin `int` / `float` as dtypes: the `np.int` / `np.float`
# aliases were removed in NumPy 1.24 and raise AttributeError there. The
# builtins select exactly the same dtypes the aliases used to.
labels = np.full(n_draws, initial_state, dtype=int)
data = np.empty(n_draws, dtype=float)

# Draw the hidden state path first: each state is sampled from the
# transition row of the state before it.
for t in range(n_draws - 1):
    i = labels[t]
    labels[t + 1] = rs.choice(n_components, p=transmat[i, :])

# Then draw one observation per time step from that step's state.
for t in range(n_draws):
    k = labels[t]
    data[t] = rs.normal(means[k], stds[k])
# Settings shared by the figures drawn below.
n_points = 250  # number of sample points along each density curve
line_color = (.5, .5, .5, .25)  # RGBA tuple passed as the line color
line_kwds = dict(alpha=.5, linewidth=1.2)
scatter_kwds = dict(s=5)

# One x-coordinate per simulated observation.
time_index = np.arange(len(data))
# Figure 1: the whole time-series with its points grouped by state,
# next to the density of each state's emission distribution.

# Lay out a wide panel for the series and a narrow one for the
# densities; the two panels share their y-axis.
fig = plt.figure()
gs = plt.GridSpec(1, 2, width_ratios=[4, 1])
ax_ts = fig.add_subplot(gs[0, 0])
ax_density = fig.add_subplot(gs[0, 1], sharey=ax_ts)

# Time-series panel: one connecting line through every observation,
# then a separate scatter collection per state so that matplotlib picks
# the next color of the global palette for each of them.
ax_ts.plot(time_index, data, c=line_color)
for k in np.unique(labels):
    mask = np.equal(labels, k)
    ax_ts.scatter(time_index[mask], data[mask], **scatter_kwds)
ax_ts.set_xlabel("Time")

# Density panel: evaluate each state's normal density over the y-range
# of the series and draw it sideways (density on x, data value on y).
ymin, ymax = ax_ts.get_ybound()
ypts = np.linspace(ymin, ymax, n_points)

# The states are visited in the same order as above and this axis
# starts its own color cycle, so the curve colors line up with the
# scatter colors.
for k in np.unique(labels):
    densities = norm(loc=means[k], scale=stds[k]).pdf(ypts)
    ax_density.plot(densities, ypts)

# The density panel carries no ticks of its own.
remove_ticks(ax_density, xaxis=True, yaxis=True)

save_and_close(fig, "state-sequence-1.png")
# It may be of interest to show each state as its own signal.

# Create figure and subplots to be drawn on.
fig = plt.figure()
gs = plt.GridSpec(1, 2, width_ratios=[4, 1])
ax_ts = fig.add_subplot(gs[0, 0])
ax_density = fig.add_subplot(gs[0, 1], sharey=ax_ts)

# Plot time-series. Every state gets its own line and its own scatter
# collection, so a state's points are only connected to each other.
for k in np.unique(labels):
    mask = labels == k
    ax_ts.plot(time_index[mask], data[mask], **line_kwds)
    ax_ts.scatter(time_index[mask], data[mask], **scatter_kwds)
ax_ts.set_xlabel("Time")

# Plot densities with data on y-axis. The y-range is read once, after
# all the series have been drawn (this computation used to be repeated
# twice in a row with identical results).
ymin, ymax = ax_ts.get_ybound()
ypts = np.linspace(ymin, ymax, n_points)
for k in np.unique(labels):
    densities = norm.pdf(ypts, loc=means[k], scale=stds[k])
    ax_density.plot(densities, ypts)

# Remove ticks from the density subplot.
remove_ticks(ax_density, xaxis=True, yaxis=True)

save_and_close(fig, "state-sequence-2.png")
# Figure 3: give every state a subplot of its own, with all the
# densities collected in a single panel on the right.
cmap = plt.get_cmap("Dark2")  # Not sure what ggplot uses...

# Sort the states by decreasing mean so that the subplot order matches
# what the eye expects (readable top-down/bottom-up).
states = np.argsort(means)[::-1]
K = len(states)

# Left column: one time-series subplot per state, all sharing both
# axes. Right column: one density panel spanning every row.
fig = plt.figure()
gs = plt.GridSpec(K, 2, width_ratios=[4, 1])
axes_ts = [fig.add_subplot(gs[row, 0]) for row in range(K)]
share_axes(*axes_ts, xaxis=True, yaxis=True)
ax_density = fig.add_subplot(gs[:, 1])

# Draw each state's observations on its own subplot, colored by the
# state's index into the colormap.
for row, (k, ax) in enumerate(zip(states, axes_ts)):
    mask = labels == k
    colors = cmap(labels[mask])
    ax.plot(time_index[mask], data[mask], c=cmap(k), **line_kwds)
    ax.scatter(time_index[mask], data[mask], c=colors, **scatter_kwds)

    # Only the bottom-most subplot of the column keeps its x-ticks.
    if row < K - 1:
        remove_ticks(ax, xaxis=True, yaxis=False)

axes_ts[-1].set_xlabel("Time")

# Evaluate every component's density over the full y-range covered by
# the time-series subplots: lowest lower bound to highest upper bound.
ybounds = np.array([ax.get_ybound() for ax in axes_ts])
ymin = ybounds[:, 0].min()
ymax = ybounds[:, 1].max()
ypts = np.linspace(ymin, ymax, n_points)
for k in states:
    densities = norm(loc=means[k], scale=stds[k]).pdf(ypts)
    ax_density.plot(densities, ypts, c=cmap(k))

# Tidy the density panel, but keep its y-axis labels since its ticks no
# longer line up with those of the time-series subplots.
remove_ticks(ax_density, xaxis=True, yaxis=False)
ax_density.yaxis.tick_right()

save_and_close(fig, "state-sequence-3.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment