@FanchenBao
Created October 11, 2022 05:41
A function to generate a grouped bar plot. Read the docstring for more information.
from typing import Optional, Sequence

import numpy as np


def grouped_barplot(
    ax,
    df,
    xlabel: str = '',
    ylabel: str = '',
    title: str = '',
    width: float = 0.6,
    loc: str = '',
    bbox_to_anchor: Optional[Sequence[float]] = None,
):
    """Generate a grouped barplot such as the one shown in the
    [doc](https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html)

    :param ax: The axis where the grouped barplot is to be created.
    :param df: Plotting data as a pandas dataframe. This dataframe must follow
        these rules. 1. Its index must list all x-axis ticks. 2. Each of its
        columns must correspond to one of the bars in a group. For instance,
        if index = ['day1', 'day2', 'day3'], columns
        = ['treatment1', 'treatment2', 'treatment3', 'control'], and each entry
        is the mean value of applying the treatment specified in the column on
        the day specified in the index, then the final plot will contain
        three groups in total, one per index entry, with four bars per group,
        one per column.
    :param xlabel: The label for the x-axis. Defaults to empty string, i.e. no
        xlabel.
    :param ylabel: The label for the y-axis. Defaults to empty string, i.e. no
        ylabel.
    :param title: Title for the current axis. Defaults to empty string, i.e. no
        title.
    :param width: The width of each individual bar. Defaults to 0.6.
    :param loc: Location of the legend box. This is the same 'loc' argument as in
        [matplotlib.legend.Legend](https://matplotlib.org/stable/api/legend_api.html#matplotlib.legend.Legend)
        Defaults to empty string, i.e. no customization of the legend location.
    :param bbox_to_anchor: The position of the legend box. This is the same
        'bbox_to_anchor' argument as in
        [matplotlib.legend.Legend](https://matplotlib.org/stable/api/legend_api.html#matplotlib.legend.Legend)
        Defaults to None, i.e. no customization of the legend anchor.
    """
    num_bars = df.columns.shape[0]
    # Offset of each bar from its group's center, in units of bar width.
    # The offsets are symmetric around zero, e.g. 4 bars -> -1.5, -0.5, 0.5, 1.5
    deltas = np.arange(-(num_bars - 1) / 2, (num_bars - 1) / 2 + 1, 1)
    # The positions where the labels show on the x-axis: one tick per group,
    # spaced num_bars apart so that groups stay separate when width < 1
    x = np.arange(0, num_bars * df.index.shape[0], num_bars)
    for col, delta in zip(df.columns, deltas):
        ax.bar(
            x + delta * width,
            df.loc[:, col],
            width,
            label=col,
        )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(df.index.astype(str))
    # Pass along only the legend options the caller actually supplied
    legend_kwargs = {}
    if loc:
        legend_kwargs['loc'] = loc
    if bbox_to_anchor:
        legend_kwargs['bbox_to_anchor'] = bbox_to_anchor
    ax.legend(**legend_kwargs)
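
The bar placement in the function above is easier to follow with concrete numbers. The sketch below restates the same arithmetic in plain Python, without NumPy, for the docstring's example of three index entries and four columns. The helper name `bar_centers` is hypothetical and not part of the gist; it only illustrates how the tick positions and per-bar offsets combine.

```python
def bar_centers(num_groups, num_bars, width=0.6):
    """Return the tick positions and, per group, the center of each bar.

    Hypothetical helper mirroring the arithmetic in grouped_barplot.
    """
    # Offsets symmetric around zero, one per bar: 4 bars -> -1.5, -0.5, 0.5, 1.5
    deltas = [i - (num_bars - 1) / 2 for i in range(num_bars)]
    # One tick per group, spaced num_bars apart
    ticks = [g * num_bars for g in range(num_groups)]
    # Each bar sits at its group's tick plus its offset scaled by the bar width
    centers = [[t + d * width for d in deltas] for t in ticks]
    return ticks, centers


ticks, centers = bar_centers(num_groups=3, num_bars=4)
print(ticks)  # [0, 4, 8]
for row in centers:
    # Rounded only for display; the raw values are ordinary floats
    print([round(c, 2) for c in row])
```

Each group occupies `num_bars * width` units on the x-axis while the ticks are `num_bars` units apart, so adjacent groups stay visually separate as long as `width` is below 1.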