Created December 25, 2019 00:55
-
-
Save Per48edjes/1f39f3e3db10ae96d70e13c12228402b to your computer and use it in GitHub Desktop.
Function to plot a stacked bar chart
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import matplotlib.pyplot as plt
def plot_stacked_bar(data, series_labels, category_labels=None,
                     show_values=False, value_format="{}", y_label=None,
                     colors=None, grid=True, reverse=False):
    """Plot a stacked bar chart on the current matplotlib figure.

    Each row of ``data`` is one series and each column is one category
    (one bar on the x-axis). Series are stacked bottom-to-top in row
    order: each segment starts where the previous series' segment in the
    same column ended.

    Keyword arguments:
    data -- 2-dimensional numpy array or nested list containing the data
            for each series in rows (n_series x n_categories). A
            1-dimensional sequence is treated as a single series.
    series_labels -- list of series labels, one per row of ``data``
                     (these appear in the legend)
    category_labels -- optional list of category labels, one per column
                       of ``data`` (these appear on the x-axis)
    show_values -- if True, each segment's height is drawn as a text
                   label centred on that segment
    value_format -- format string applied to each segment height
                    (default is "{}")
    y_label -- label for the y-axis (str); skipped when falsy
    colors -- optional list of colors, one per series; when None,
              matplotlib's default color cycle is used
    grid -- if True, display a grid
    reverse -- if True, reverse the left-to-right order of the
               categories (bars) on the x-axis; the category labels are
               reversed along with them. The stacking order of the
               series is not affected.

    Raises:
    ValueError -- if ``data`` has more than two dimensions

    Note: a single running baseline is used for stacking, so a negative
    value is drawn downward from the current top of its stack rather
    than being stacked separately below zero.
    """
    data = np.atleast_2d(np.asarray(data))
    if data.ndim != 2:
        raise ValueError(
            "data must be 1- or 2-dimensional, got {} dimensions".format(
                data.ndim))
    n_series, n_categories = data.shape

    # Materialize the labels as a list: this accepts any iterable
    # (including a numpy array, whose truth value is ambiguous), tolerates
    # None, and can be reversed in place.
    labels = list(category_labels) if category_labels is not None else []
    # plt.bar(color=None) falls back to the default color cycle.
    if colors is None:
        colors = [None] * n_series

    if reverse:
        # Flip the column (category) order; the labels must follow the bars.
        data = np.flip(data, axis=1)
        labels.reverse()

    positions = list(range(n_categories))
    bottoms = np.zeros(n_categories)  # running top of each stack
    containers = []  # one BarContainer per series
    for i, row in enumerate(data):
        containers.append(plt.bar(positions, row, bottom=bottoms,
                                  label=series_labels[i], color=colors[i]))
        bottoms += row

    if labels:
        plt.xticks(positions, labels)
    if y_label:
        plt.ylabel(y_label)
    plt.legend()
    if grid:
        plt.grid()

    if show_values:
        for container in containers:
            for bar in container:
                w, h = bar.get_width(), bar.get_height()
                plt.text(bar.get_x() + w / 2, bar.get_y() + h / 2,
                         value_format.format(h), ha="center", va="center")
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment