Created
August 11, 2021 15:16
-
-
Save michaelsilverstein/7f857dd6a2768a393613f6f8e13b5f59 to your computer and use it in GitHub Desktop.
Stacked barplot function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def stackedbarplot(data, stack_order=None, palette=None, **barplot_kws):
    """
    Create a stacked barplot

    Each stack level is drawn as its own `sns.barplot()` layer whose heights are
    the cumulative sums up to that level. Layers are drawn from the last level
    to the first, so the shorter (earlier) layers paint over the taller ones.

    Inputs:
    | data <pd.DataFrame>: A wideform dataframe where the index is the variable to stack, the columns are different samples (x-axis), and the cells the counts (y-axis)
    | stack_order <array-like>: The order for bars to be stacked (Default: given order)
    | palette <array-like, dict, or str>: The colors to use for each value of `stack_order`.
    |     A dict is used as-is (keys are stack levels); a sequence or seaborn palette name
    |     is matched to `stack_order` position by position (Default: husl)
    | barplot_kws: Arguments to pass to sns.barplot()
    Returns:
    | The matplotlib Axes the bars were drawn on
    Author: Michael Silverstein
    Usage: https://github.com/michaelsilverstein/Pandas-and-Plotting/blob/master/lessons/stacked_bar_chart.ipynb
    """
    from collections.abc import Mapping

    # Materialize the order as a list: a tuple would be read by `.loc` as a
    # (row, column) key, and arbitrary iterables cannot be reversed or sized.
    stack_order = list(data.index if stack_order is None else stack_order)
    n_levels = len(stack_order)

    # Every barplot call below only sees ONE hue level, so seaborn must be given
    # an explicit level -> color mapping; a bare sequence or palette name would
    # make every layer take the first color.
    if palette is None:
        palette = dict(zip(stack_order, sns.husl_palette(n_levels)))
    elif not isinstance(palette, Mapping) and n_levels:
        base = palette if isinstance(palette, str) else list(palette)
        palette = dict(zip(stack_order, sns.color_palette(base, n_levels)))

    # Cumulative heights, in stacking order
    cumulative = data.loc[stack_order].cumsum()
    # Melt to long form for seaborn: one row per (stack level, sample)
    long_form = cumulative.stack().reset_index(name='count')
    # The first two columns hold the stack variable and the sample variable
    stack_name, sample_name = long_form.columns[:2]

    # Track the Axes actually drawn on so a caller-supplied `ax=` is honored
    ax = barplot_kws.get('ax')
    for level in reversed(stack_order):
        # Subset to this stack level
        layer = long_form[long_form[stack_name].eq(level)]
        ax = sns.barplot(x=sample_name, y='count', hue=stack_name, palette=palette, data=layer, **barplot_kws)

    # Nothing was drawn and no Axes was supplied: fall back to the current Axes
    return plt.gca() if ax is None else ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment