Skip to content

Instantly share code, notes, and snippets.

@kvalv
Created October 31, 2018 17:42
Show Gist options
  • Save kvalv/3a3912bcea9236ed57ae1e78bdad2a29 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
import seaborn
def stack_plot(stacks, bars, xlabels=None, title=None):
    """Draw, for each of N groups, a stacked bar plus a black bar beside it.

    Drawing happens on the current matplotlib figure/axes
    (``plt.gcf()`` / ``fig.gca()``).

    Parameters
    ----------
    stacks : array_like of shape (N, K)
        For each of the N groups, K segment heights stacked on top of each
        other (column i is drawn in colour i of a K-colour cubehelix palette).
    bars : array_like of shape (N,)
        One height per group, drawn as a black bar just to the right of the
        group's stack.
    xlabels : sequence of str, optional
        Tick labels, one per group. Defaults to the group indices "0".."N-1".
    title : str, optional
        Axes title. Left unset when None.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on.

    Raises
    ------
    ValueError
        If ``stacks`` is not two-dimensional.
    """
    stacks = np.asarray(stacks)
    bars = np.asarray(bars)
    if stacks.ndim != 2:
        raise ValueError(
            f"stacks must be a 2-D array of shape (N, K), got shape {stacks.shape}"
        )
    N, K = stacks.shape

    fig = plt.gcf()
    ax = fig.gca()

    # Segment i of a group must start where segments 0..i-1 end, i.e. at the
    # running (cumulative) sum along K. Using only the previous segment's own
    # height as the bottom would make segments overlap instead of stack.
    tops = np.cumsum(stacks, axis=1)
    bottoms = np.hstack((np.zeros((N, 1), dtype=tops.dtype), tops[:, :-1]))

    # One colour per stacked class.
    color_palette = seaborn.cubehelix_palette(K)

    # x-locations, one per group.
    x = np.arange(0, N, dtype=int)
    width = 0.25

    for i, color in zip(range(K), color_palette):
        ax.bar(x, height=stacks[:, i], width=width, bottom=bottoms[:, i], color=color)

    # A single call draws all N companion bars; the offset (0.30 > width 0.25)
    # leaves a small gap between each stack and its black bar.
    offset = 0.30
    ax.bar(x + offset, height=bars, width=width, color='k')  # 'k' is black

    ax.set_xticks(x)
    if xlabels is None:
        xlabels = [str(i) for i in x]
    ax.set_xticklabels(xlabels)

    # Cosmetics: drop the top/right spines and let matplotlib rotate and
    # right-align the tick labels so long group names do not collide.
    seaborn.despine()
    fig.autofmt_xdate()
    if title is not None:
        ax.set_title(title)
    return ax
# Demo: for each of N kindergartens, stack the number of children of each age
# (2-5) and draw one additional random value per kindergarten as a black bar.
N = 10  # number of kindergartens
K = 4  # ages 2, 3, 4 and 5 -> 4 age groups
rng = np.random.RandomState(0)  # fixed seed so the demo figure is reproducible
stacks = rng.randint(1, 100, (N, K))
bars = rng.randint(1, 100, N)
xlabels = [f'kindergarten {i}' for i in range(N)]
stack_plot(stacks, bars, xlabels, 'Children per age group in each kindergarten')
# Needed to actually display the figure when run as a plain script.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment