Forked from vignesh-saptarishi/mpl_pandas_plot_tools.py
Created
November 16, 2022 16:02
-
-
Save maria-aguilera/b2430d68a9cce9728605422b36cbcbc1 to your computer and use it in GitHub Desktop.
Utility functions for visualizing pandas DataFrames with matplotlib (and seaborn).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
import pandas | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from ggplot import * | |
# Switch matplotlib to its built-in 'ggplot' style sheet for the figures drawn below.
plt.style.use(style='ggplot')
def get_histogram_xy(data, bins=10):
    """Compute histogram points suitable for drawing as a line or markers.

    Args:
        data (array-like): values to bin.
        bins (int or sequence or str): forwarded unchanged to
            ``numpy.histogram`` (defaults to 10 equal-width bins).

    Returns:
        array (float): x-coordinates, the mid-point of every bin.
        array (int): y-coordinates, the number of values in every bin.
    """
    counts, edges = numpy.histogram(data, bins=bins)
    left, right = edges[:-1], edges[1:]
    centers = (left + right) / 2
    return centers, counts
# ggplot2-style hue palettes for 1-10 colours, kept verbatim so results for
# these sizes are exactly what they always were.
_GGPLOT_HUE_COLORS = {
    1: ["#F8766D"],
    2: ["#F8766D", "#00BFC4"],
    3: ["#F8766D", "#00BA38", "#619CFF"],
    4: ["#F8766D", "#7CAE00", "#00BFC4", "#C77CFF"],
    5: ["#F8766D", "#A3A500", "#00BF7D", "#00B0F6", "#E76BF3"],
    6: ["#F8766D", "#B79F00", "#00BA38", "#00BFC4", "#619CFF", "#F564E3"],
    7: ["#F8766D", "#C49A00", "#53B400", "#00C094", "#00B6EB", "#A58AFF", "#FB61D7"],
    8: ["#F8766D", "#CD9600", "#7CAE00", "#00BE67", "#00BFC4", "#00A9FF", "#C77CFF", "#FF61CC"],
    9: ["#F8766D", "#D39200", "#93AA00", "#00BA38", "#00C19F", "#00B9E3", "#619CFF", "#DB72FB", "#FF61C3"],
    10: ["#F8766D", "#D89000", "#A3A500", "#39B600", "#00BF7D", "#00BFC4", "#00B0F6", "#9590FF", "#E76BF3", "#FF62BC"],
}


def _ggplot_hue_palette(num):
    """Generate ``num`` ggplot2-style hex colours (used for more than 10 colours).

    Hues start at 15 degrees and step by ``360 / num`` degrees. Each hue is taken
    at chroma 100 / luminance 65 in HCL (polar CIE-Luv) and converted to sRGB,
    following the conversion and [0, 1] clamping of R's ``hcl()``.
    """
    import math

    chroma, luminance = 100.0, 65.0
    # D65 reference white in CIE u'v' coordinates.
    white_u, white_v = 0.1978398, 0.4683363
    # Luminance -> CIE Y (white Y = 1); 65 is well above the linear toe of the curve.
    y = ((luminance + 16) / 116) ** 3

    def encode(linear):
        # sRGB transfer curve, clamp out-of-gamut values, scale to 0-255.
        if linear > 0.00304:
            value = 1.055 * linear ** (1 / 2.4) - 0.055
        else:
            value = 12.92 * linear
        return int(255 * min(max(value, 0.0), 1.0) + 0.5)

    colors = []
    for i in range(num):
        hue = math.radians((15 + 360.0 * i / num) % 360)
        u = chroma * math.cos(hue) / (13 * luminance) + white_u
        v = chroma * math.sin(hue) / (13 * luminance) + white_v
        x = 9 * y * u / (4 * v)
        z = -x / 3 - 5 * y + 3 * y / v
        red = encode(3.240479 * x - 1.537150 * y - 0.498535 * z)
        green = encode(-0.969256 * x + 1.875992 * y + 0.041556 * z)
        blue = encode(0.055648 * x - 0.204043 * y + 1.057311 * z)
        colors.append("#%02X%02X%02X" % (red, green, blue))
    return colors


def get_color_list(num=2, style='ggplot'):
    """
    Return a colour palette as a list based on the number of colours.

    Parameters
        num: number of colours wanted
        style: 'ggplot' (default) or 'seaborn'
            'ggplot': ggplot2-style hue palette as hex strings. 1-10 colours
            come from a fixed table; more than 10 are generated with the same
            evenly spaced HCL hue rule. Returns None when ``num`` is below 1.
            'seaborn': the first ``num`` colours of ``sns.color_palette()``, so
            at most as many colours as that palette contains.

    Returns
        list of colours, or None (see above)

    Raises
        ValueError: if ``style`` is neither 'ggplot' nor 'seaborn'
    """
    if style == 'seaborn':
        return sns.color_palette()[:num]
    if style == 'ggplot':
        if num in _GGPLOT_HUE_COLORS:
            # Copy so callers cannot mutate the shared table.
            return list(_GGPLOT_HUE_COLORS[num])
        if num > 10:
            return _ggplot_hue_palette(num)
        return None
    # Previously an unknown style fell through to an UnboundLocalError.
    raise ValueError("style must be 'ggplot' or 'seaborn', got %r" % (style,))
def get_group_by_data(data=None, xvar=None, yvar=None, cvar=None, exclude=None):
    """
    Split the ``xvar``/``yvar`` columns into one series per value of ``cvar``.

    Useful for plotting coloured scatter plots one group at a time.

    Parameters
        data: dataframe
        xvar, yvar: columns whose values are returned per group
        cvar: categorical column used to split the rows (required)
        exclude: iterable of ``cvar`` values to leave out

    Returns
        x-values: list of ``data[xvar]`` series, one per kept ``cvar`` value
        y-values: list of ``data[yvar]`` series, one per kept ``cvar`` value
        legends: the kept ``cvar`` values in sorted order (for the legend)

    Raises
        ValueError: if ``cvar`` is not given, or if ``exclude`` removes every
            ``cvar`` value
    """
    if cvar is None:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError("'cvar' is required to group the data")
    cvar_vals = sorted(data[cvar].unique())
    if exclude:
        excluded = list(exclude)
        # One pass, keeping the original `!=` comparison semantics.
        cvar_vals = [val for val in cvar_vals if all(val != exc for exc in excluded)]
        if len(cvar_vals) < 1:
            raise ValueError('Nothing in \'cvar\' category')
    xlist = []
    ylist = []
    legends = []
    for cval in cvar_vals:
        # Filter the rows once per group and reuse them for both columns.
        rows = data[data[cvar] == cval]
        xlist.append(rows[xvar])
        ylist.append(rows[yvar])
        legends.append(cval)
    return xlist, ylist, legends
def plot_scatter(data=None, xvar=None, yvar=None, cvar=None, exclude=None, ax=None, labels=True, ticks=True):
    """
    Scatter-plot two columns, optionally coloured by a third categorical column.

    The same picture can be made with one of:
        ``plt.scatter(var1, var2, c=var3)``
        ``ggplot(aes(x=var1, y=var2, color=var3), data=df) + geom_point()``
    but both take a long time to plot. This function instead splits the data
    with ``get_group_by_data`` and issues one ``ax.scatter`` call per group.

    Parameters
        data: dataframe
        xvar, yvar: columns plotted on the x and y axes
        cvar: categorical column used to colour points (optional)
        exclude: ``cvar`` values to leave out (see ``get_group_by_data``)
        ax: axis to make the plot on;
            if not provided, a new 7x5 inch figure and axis are created
        labels: if True, label the axes with the column names (and add a
            legend titled ``cvar`` when colouring by ``cvar``)
        ticks: if falsy, hide the tick labels on both axes

    Returns
        ax when one was passed in, otherwise None
    """
    ax_passed = ax is not None
    if not cvar:
        # no need to categorize: a single group holding every row
        xlist = [data[xvar]]
        ylist = [data[yvar]]
        legends = ['None']
    else:
        xlist, ylist, legends = get_group_by_data(data=data, xvar=xvar, yvar=yvar, cvar=cvar, exclude=exclude)
    color_list = get_color_list(num=len(xlist))
    if color_list is None:
        # No fixed palette for this many groups: let matplotlib's colour cycle pick.
        color_list = [None] * len(xlist)
    if not ax_passed:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        fig.set_size_inches(7, 5)
    for x, y, label, color in zip(xlist, ylist, legends, color_list):
        if not cvar:
            ax.scatter(x, y, s=30, c=color, edgecolor=color)
        else:
            ax.scatter(x, y, s=30, c=color, edgecolor=color, label=label)
    if labels:
        ax.set_xlabel(xvar)
        ax.set_ylabel(yvar)
        if cvar:
            ax.legend(title=cvar)
    if not ticks:
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
    if ax_passed:
        return ax
    return None
def get_subplots(naxes=4, layout='matrix', nrows=None, ncols=None):
    """
    Create a matplotlib figure and a grid of subplot axes based on layout.

    Parameters
        naxes: number of axes/sub-plots in the figure
        layout: shape of the grid when ``nrows``/``ncols`` are not both given
            'matrix'     -> ``floor(sqrt(naxes))`` rows and as many columns as
                            needed to fit ``naxes`` axes
            'horizontal' -> a single row
            'vertical'   -> a single column
            (should try generalizing with GridSpec)
        nrows, ncols: when both are given they override ``layout`` and
            ``naxes`` becomes ``nrows * ncols``

    Returns
        fig: the figure object
        axes: (nrows, ncols) numpy object array of axes in the shape of the
            figure; cells beyond ``naxes`` (only possible for a non-square
            'matrix' layout) hold None

    Raises
        ValueError: if ``layout`` is not one of the names above
    """
    if nrows and ncols:
        naxes = nrows * ncols
    elif layout == 'matrix':
        nrows = int(numpy.sqrt(naxes))
        # ceil(naxes / nrows): the old ncols = nrows left no room for e.g. naxes=5
        ncols = -(-naxes // nrows) if nrows else 0
    elif layout == 'horizontal':
        nrows, ncols = 1, naxes
    elif layout == 'vertical':
        nrows, ncols = naxes, 1
    else:
        raise ValueError(
            "layout must be 'matrix', 'horizontal' or 'vertical', got %r" % (layout,))
    fig = plt.figure()
    # Object array pre-filled with None; real axes are placed row by row.
    axes = numpy.empty((nrows, ncols), dtype=object)
    for k in range(naxes):
        row, col = divmod(k, ncols)
        axes[row, col] = fig.add_subplot(nrows, ncols, k + 1)
    return fig, axes
def get_group_by_data2(data=None, xvar=None, cvar=None, exclude=None):
    """
    Split one column into a series per value of ``cvar``.

    Single-variable counterpart of ``get_group_by_data``.

    Parameters
        data: dataframe
        xvar: column whose values are returned per group
        cvar: categorical column used to split the rows (required)
        exclude: iterable of ``cvar`` values to leave out (optional;
            previously accepted but ignored)

    Returns
        x-values: list of ``data[xvar]`` series, one per kept ``cvar`` value
        legends: the kept ``cvar`` values in sorted order

    Raises
        ValueError: if ``cvar`` is not given
    """
    if cvar is None:
        # Previously this surfaced as an UnboundLocalError.
        raise ValueError("'cvar' is required to group the data")
    cvar_vals = sorted(data[cvar].unique())
    if exclude:
        excluded = list(exclude)
        cvar_vals = [val for val in cvar_vals if all(val != exc for exc in excluded)]
    xlist = []
    legends = []
    for cval in cvar_vals:
        xlist.append(data.loc[data[cvar] == cval, xvar])
        legends.append(cval)
    return xlist, legends
def plot_hist(data=None, xvar=None, bins=15, cvar=None, ax=None, labels=True, ticks=True):
    """
    Plot a histogram of one column, optionally one histogram per ``cvar`` value.

    Parameters
        data: dataframe
        xvar: column to histogram
        bins: bin count, explicit edges or a numpy bin rule; with ``cvar``
            every group is binned on the same edges so the bars line up
        cvar: categorical column; each value gets its own coloured
            histogram, drawn on top of the previous ones (optional)
        ax: axis to make the plot on;
            if not provided, a new 7x5 inch figure and axis are created
        labels: if True, label the axes (and add a legend titled ``cvar``)
        ticks: if falsy, hide the tick labels on both axes

    Returns
        ax when one was passed in, otherwise None
    """
    ax_passed = ax is not None
    # create figure and ax if axes not already passed
    if not ax_passed:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        fig.set_size_inches(7, 5)
    if cvar:
        # color by a variable
        xlist, legends = get_group_by_data2(data=data, xvar=xvar, cvar=cvar)
        color_list = get_color_list(num=len(xlist))
        if color_list is None:
            # No fixed palette for this many groups: let matplotlib's colour cycle pick.
            color_list = [None] * len(xlist)
        # Edges from the whole column (NaNs dropped) so that every group's bars
        # cover the same intervals; per-group binning made them incomparable.
        edges = numpy.histogram_bin_edges(data[xvar].dropna(), bins=bins)
        for x, label, color in zip(xlist, legends, color_list):
            ax.hist(x, bins=edges, color=color, label=str(label))
    else:
        ax.hist(data[xvar], bins=bins, color=get_color_list(1)[0])
    if labels:
        ax.set_xlabel(xvar)
        ax.set_ylabel('frequency')
        if cvar:
            ax.legend(title=cvar)
    if not ticks:
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
    if ax_passed:
        return ax
    return None
def plot_pairs(data=None, vars=None, cvar=None):
    """
    Plot a matrix of bivariate relationships across the provided ``vars``.

    Off-diagonal cell (i, j) scatters ``vars[j]`` (x) against ``vars[i]`` (y);
    diagonal cells are histograms of ``vars[i]``. Only the left column keeps
    y labels/tick labels and only the bottom row keeps x labels/tick labels.

    Parameters
        data: dataframe
        vars: list of columns to plot; defaults to every numeric column
            except ``cvar``
        cvar: column used to colour/group the data (optional)

    Returns
        fig: the matplotlib figure

    Not well tested.
    Things to add:
        - Get off-diagonal grid lines on diagonal plots and get corresponding yticklabels
    """
    if vars is None:
        vars = [col for col in data.select_dtypes(include='number').columns if col != cvar]
    n = len(vars)
    fig, axes = get_subplots(naxes=n * n, layout='matrix')
    for i, yvar in enumerate(vars):
        for j, xvar in enumerate(vars):
            ax = axes[i, j]
            ax.ticklabel_format(axis='both', style='sci', scilimits=(-3, 3))
            if i == j:
                plot_hist(data=data, xvar=yvar, cvar=cvar, ax=ax, labels=False, ticks=True)
            else:
                plot_scatter(data=data, xvar=xvar, yvar=yvar, cvar=cvar, ax=ax, labels=False, ticks=True)
            # x, y labels and tick labels only on the outer edge of the grid
            if j == 0:
                ax.set_ylabel(yvar)
            else:
                ax.yaxis.set_major_formatter(plt.NullFormatter())
            if i == n - 1:
                ax.set_xlabel(xvar)
            else:
                ax.xaxis.set_major_formatter(plt.NullFormatter())
    for i in range(n):
        if n > 1:
            # Match each histogram's x-range to a scatter plot in the same
            # column (row 1 for the first column, the row above otherwise).
            # With a single variable there is no scatter plot to copy from.
            ref_row = 1 if i == 0 else i - 1
            axes[i, i].set_xlim(axes[ref_row, i].get_xlim())
        # Positional flag: the `b=` keyword was removed in matplotlib 3.7.
        axes[i, i].grid(False)
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig
def plot_pairs_sns(data=None, vars=None, cvar=None):
    """
    Plot pair-wise relationships using seaborn's ``PairGrid``.

    Diagonal cells use ``plt.hist``; both triangles use ``plt.scatter``.
    When ``cvar`` is given, it is used as the hue with the palette from
    ``get_color_list`` and a legend is added.

    Parameters
        data: dataframe
        vars: columns to plot (None leaves the choice to seaborn)
        cvar: column used for the hue (optional)

    Returns
        the seaborn ``PairGrid``
    """
    palette = None
    if cvar is not None:
        # nunique() skips NaN, matching the hue levels seaborn actually draws.
        # The palette is handed to this grid only instead of calling
        # sns.set_palette, which changed colours for every later seaborn plot.
        palette = get_color_list(num=data[cvar].nunique())
    grid = sns.PairGrid(data, vars=vars, hue=cvar, palette=palette)
    grid.map_diag(plt.hist)
    grid.map_lower(plt.scatter)
    grid.map_upper(plt.scatter)
    if cvar is not None:
        grid.add_legend()
    return grid
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment