@michaelsilverstein
Created August 11, 2021 15:16
Stacked barplot function
import matplotlib.pyplot as plt
import seaborn as sns


def stackedbarplot(data, stack_order=None, palette=None, **barplot_kws):
    """
    Create a stacked barplot
    Inputs:
    | data <pd.DataFrame>: A wideform dataframe where the index is the variable to stack, the columns are different samples (x-axis), and the cells are the counts (y-axis)
    | stack_order <array-like>: The order in which levels are stacked, from bottom to top (Default: order of `data.index`)
    | palette <dict or array-like>: The colors to use for each value of `stack_order` (Default: husl)
    | barplot_kws: Additional keyword arguments passed to sns.barplot()
    Author: Michael Silverstein
    Usage: https://github.com/michaelsilverstein/Pandas-and-Plotting/blob/master/lessons/stacked_bar_chart.ipynb
    """
    # Order the stack levels (default: order of the index)
    if stack_order is None:
        stack_order = data.index
    # Create a husl palette if none is given
    if palette is None:
        palette = sns.husl_palette(len(stack_order))
    # Map each stack level to its color. Each barplot call below sees only one
    # hue level, so a plain list would color every layer with its first color.
    if not isinstance(palette, dict):
        palette = dict(zip(stack_order, palette))
    # Cumulative sum down the stack: each row is the top edge of that layer
    cumsum = data.loc[stack_order].cumsum()
    # Melt to long form for passing to seaborn
    cumsum_stacked = cumsum.stack().reset_index(name='count')
    # Get the names of the stack variable and the sample variable
    stack_name, sample_name = cumsum_stacked.columns[:2]
    # Draw the tallest (last) cumulative bars first; each shorter bar drawn
    # afterwards covers the lower part of the previous one, producing the stack
    for s in stack_order[::-1]:
        # Subset to this stack level
        d = cumsum_stacked[cumsum_stacked[stack_name].eq(s)]
        sns.barplot(x=sample_name, y='count', hue=stack_name, palette=palette, data=d, **barplot_kws)
    return plt.gca()
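To make the stacking trick concrete, here is a minimal sketch of the data transformation inside `stackedbarplot`, run on a small made-up table (the taxon and sample names and the counts are hypothetical). It needs only pandas, reproduces the `cumsum` / `stack` / `reset_index` steps, and checks that once the cumulative bars are overdrawn from tallest to shortest, the visible height of each layer equals the original count.

```python
import pandas as pd

# Hypothetical wideform counts: stack levels (taxa) in the index,
# samples in the columns, counts in the cells
counts = pd.DataFrame(
    {"sample1": [5, 3, 2], "sample2": [1, 4, 4]},
    index=pd.Index(["taxonA", "taxonB", "taxonC"], name="taxon"),
)
counts.columns.name = "sample"

stack_order = ["taxonA", "taxonB", "taxonC"]

# Cumulative sum down the index: each row is the top edge of that layer
cumsum = counts.loc[stack_order].cumsum()

# Long form, one row per (stack level, sample), as passed to sns.barplot
cumsum_stacked = cumsum.stack().reset_index(name="count")
stack_name, sample_name = cumsum_stacked.columns[:2]
print(cumsum_stacked)

# Drawing the tallest layer first and shorter layers on top leaves each
# layer visible only between the previous layer's top and its own top,
# so the visible heights are the first differences of the cumulative sums
visible = cumsum.diff().fillna(cumsum)
print(visible.equals(counts.loc[stack_order].astype(float)))
```

Passing the same `counts` table to `stackedbarplot(counts)` should draw one bar per sample, with `taxonA` at the bottom and `taxonC` on top, since the first level of `stack_order` has the shortest cumulative bar and is drawn last.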