Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save maria-aguilera/b2430d68a9cce9728605422b36cbcbc1 to your computer and use it in GitHub Desktop.
Save maria-aguilera/b2430d68a9cce9728605422b36cbcbc1 to your computer and use it in GitHub Desktop.
Utility functions for visualization using pandas dataframes and matplotlib
import numpy
import pandas
import matplotlib.pyplot as plt
import seaborn as sns
from ggplot import *
# Module-level call: applies matplotlib's 'ggplot' style sheet as soon as this
# file is imported, so figures drawn by the helpers below use that style.
plt.style.use('ggplot')
def get_histogram_xy(data, bins=10):
    """Compute plottable (x, y) points for a histogram of ``data``.

    Args:
        data (array-like): values to bin; handed straight to ``numpy.histogram``.
        bins (int): number of bins (defaults to 10); forwarded unchanged to
            ``numpy.histogram``.

    Returns:
        numpy.ndarray (float): x-coordinates, the midpoint of every bin.
        numpy.ndarray (int): y-coordinates, the count of values in every bin.
    """
    counts, edges = numpy.histogram(data, bins=bins)
    # Each bin is [edges[k], edges[k + 1]]; its centre is the mean of the two.
    right_edges = edges[1:]
    left_edges = edges[:-1]
    midpoints = (right_edges + left_edges) / 2
    return midpoints, counts
# ggplot2's default hue palettes for 1-10 colours, kept verbatim so existing
# plots keep exactly the same colours.
_GGPLOT_PALETTES = {
    1: ("#F8766D",),
    2: ("#F8766D", "#00BFC4"),
    3: ("#F8766D", "#00BA38", "#619CFF"),
    4: ("#F8766D", "#7CAE00", "#00BFC4", "#C77CFF"),
    5: ("#F8766D", "#A3A500", "#00BF7D", "#00B0F6", "#E76BF3"),
    6: ("#F8766D", "#B79F00", "#00BA38", "#00BFC4", "#619CFF", "#F564E3"),
    7: ("#F8766D", "#C49A00", "#53B400", "#00C094", "#00B6EB", "#A58AFF", "#FB61D7"),
    8: ("#F8766D", "#CD9600", "#7CAE00", "#00BE67", "#00BFC4", "#00A9FF", "#C77CFF", "#FF61CC"),
    9: ("#F8766D", "#D39200", "#93AA00", "#00BA38", "#00C19F", "#00B9E3", "#619CFF", "#DB72FB", "#FF61C3"),
    10: ("#F8766D", "#D89000", "#A3A500", "#39B600", "#00BF7D", "#00BFC4", "#00B0F6", "#9590FF", "#E76BF3", "#FF62BC"),
}


def _ggplot_hue_palette(num):
    """Generate ``num`` hex colours the way ggplot2's default ``hue_pal()`` does.

    Hues are spaced evenly around the colour wheel starting at 15 degrees,
    with chroma 100 and luminance 65 in HCL (polar CIE-LUV) space, and are
    converted to sRGB with out-of-gamut channels clipped. The conversion is
    modelled on R's ``grDevices::hcl``; individual channels may differ from
    R's output by rounding.
    """
    import math

    chroma, luminance = 100.0, 65.0
    # D65 reference white: CIE Y and the LUV chromaticities u', v'.
    white_y, white_u, white_v = 100.0, 0.1978398, 0.4683363
    # Luminance -> CIE Y does not depend on hue, so compute it once.
    if luminance > 7.999592:
        y = white_y * ((luminance + 16) / 116) ** 3
    else:
        y = white_y * luminance / 903.3

    def to_byte(linear):
        # sRGB transfer curve, clipped to [0, 1], then scaled to 0-255.
        if linear > 0.00304:
            encoded = 1.055 * linear ** (1 / 2.4) - 0.055
        else:
            encoded = 12.92 * linear
        return int(255 * min(max(encoded, 0.0), 1.0) + 0.5)

    colors = []
    for i in range(num):
        hue = math.radians(15.0 + 360.0 * i / num)
        u = chroma * math.cos(hue) / (13 * luminance) + white_u
        v = chroma * math.sin(hue) / (13 * luminance) + white_v
        x = 9.0 * y * u / (4 * v)
        z = -x / 3 - 5 * y + 3 * y / v
        red = 3.240479 * x - 1.537150 * y - 0.498535 * z
        green = -0.969256 * x + 1.875992 * y + 0.041556 * z
        blue = 0.055648 * x - 0.204043 * y + 1.057311 * z
        colors.append('#%02X%02X%02X' % tuple(to_byte(c / white_y) for c in (red, green, blue)))
    return colors


def get_color_list(num=2, style='ggplot'):
    """
    Return a colour palette with ``num`` colours as a list.

    Parameters
    num: number of colours wanted (ggplot style returns [] for num < 1)
    style: 'ggplot' (default) or 'seaborn'.
        'ggplot' gives ggplot2's default hue palette as hex strings. Sizes
        1-10 are hard-coded; larger sizes are generated with the same HCL
        recipe (evenly spaced hues, chroma 100, luminance 65).
        'seaborn' gives the first ``num`` entries of ``sns.color_palette()``,
        which may hold fewer than ``num`` colours if that palette is shorter.

    Returns
    list of colours

    Raises
    ValueError: if ``style`` is neither 'ggplot' nor 'seaborn'.
    """
    if style == 'seaborn':
        return sns.color_palette()[:num]
    if style != 'ggplot':
        raise ValueError("style must be 'ggplot' or 'seaborn', got %r" % (style,))
    if num < 1:
        return []
    if num in _GGPLOT_PALETTES:
        # Copy so callers can mutate the result without touching the table.
        return list(_GGPLOT_PALETTES[num])
    return _ggplot_hue_palette(num)
def get_group_by_data(data=None, xvar=None, yvar=None, cvar=None, exclude=None):
    """
    Split the x and y values of a dataframe by the categories of a third column.

    Useful for plotting coloured scatter plots, one group per category.

    Parameters
    data: dataframe
    xvar, yvar: names of the columns holding the x and y values
    cvar: name of the categorical column to group by (required)
    exclude: iterable of ``cvar`` values to leave out
        Not well tested.

    Returns
    x-values: list of x-value Series, one per remaining ``cvar`` value
    y-values: list of y-value Series, in the same order
    legends: the remaining ``cvar`` values, sorted (to be used in a legend)

    Raises
    ValueError: if ``cvar`` is not given, or if no ``cvar`` value remains
        (ValueError is a subclass of Exception, which was raised before).
    """
    if not cvar:
        raise ValueError("'cvar' must name the column to group by")
    cvar_vals = sorted(data[cvar].unique())
    if exclude:
        excluded = list(exclude)
        cvar_vals = [val for val in cvar_vals if val not in excluded]
    if len(cvar_vals) < 1:
        raise ValueError('Nothing in \'cvar\' category')
    xlist = []
    ylist = []
    for cval in cvar_vals:
        # Build the row selection once per category and reuse it for x and y.
        subset = data[data[cvar] == cval]
        xlist.append(subset[xvar])
        ylist.append(subset[yvar])
    return xlist, ylist, list(cvar_vals)
def plot_scatter(data=None, xvar=None, yvar=None, cvar=None, exclude=None, ax=None, labels=True, ticks=True):
    """
    Scatter-plot two dataframe columns, coloured by a third categorical column.

    The same picture can be made with
    ``plt.scatter(var1, var2, c=var3)`` or
    ``ggplot(aes(x=var1, y=var2, color=var3), data=df) + geom_point()``,
    but both take a long time to plot. Here the data is split with
    ``get_group_by_data`` and each category gets its own ``ax.scatter`` call.

    Parameters
    data: dataframe
    xvar, yvar: columns plotted on the x and y axes
    cvar: optional categorical column used for colours and the legend
    exclude: iterable of ``cvar`` values to leave out (only used with ``cvar``)
    ax: axis to make the plot on;
        if not provided, a new 7x5-inch figure and axis are created
    labels: if True, label both axes (and add a legend when ``cvar`` is set)
    ticks: if False, hide the tick labels

    Returns
    ax when one was passed in; otherwise None.
    """
    ax_passed = ax is not None
    if cvar:
        # categorize and get grouped data
        xlist, ylist, legends = get_group_by_data(data=data, xvar=xvar, yvar=yvar, cvar=cvar, exclude=exclude)
    else:
        # no need to categorize: a single group drawn without a legend label
        xlist, ylist, legends = [data[xvar]], [data[yvar]], ['None']
    color_list = get_color_list(num=len(xlist))
    if color_list is None:
        # No palette of this size: let matplotlib's colour cycle choose the
        # colours instead of crashing in zip() below.
        color_list = [None] * len(xlist)
    if not ax_passed:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        fig.set_size_inches(7, 5)
    for x, y, legend, color in zip(xlist, ylist, legends, color_list):
        if cvar:
            ax.scatter(x, y, s=30, c=color, edgecolor=color, label=legend)
        else:
            ax.scatter(x, y, s=30, c=color, edgecolor=color)
    if labels:
        ax.set_xlabel(xvar)
        ax.set_ylabel(yvar)
        if cvar:
            ax.legend(title=cvar)
    if not ticks:
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
    if ax_passed:
        return ax
def _subplot_grid_shape(naxes, layout, nrows, ncols):
    """Resolve ``(nrows, ncols, naxes)`` for the arguments of ``get_subplots``."""
    if nrows and ncols:
        # An explicit grid wins; every cell gets an axis.
        return nrows, ncols, nrows * ncols
    if layout == 'matrix':
        # As square as possible: add columns until all naxes cells fit
        # (max() keeps naxes=0 at an empty 0x0 grid).
        nrows = int(numpy.sqrt(naxes))
        ncols = -(-naxes // max(nrows, 1))
    elif layout == 'horizontal':
        nrows, ncols = 1, naxes
    elif layout == 'vertical':
        nrows, ncols = naxes, 1
    else:
        raise ValueError("layout must be 'matrix', 'horizontal' or 'vertical', got %r" % (layout,))
    return nrows, ncols, naxes


def get_subplots(naxes=4, layout='matrix', nrows=None, ncols=None):
    """
    Create a matplotlib figure and subplot axes based on layout.

    Parameters
    naxes: number of axes/sub-plots in the figure
    layout: shape of the figure ('vertical', 'horizontal' or 'matrix')
        (should try generalizing with GridSpec)
    nrows, ncols: when both are given they override the other settings

    Returns
    fig: the figure object
    axes: (nrows, ncols) array of axes objects in the shape of the figure;
        for a 'matrix' layout whose naxes is not a perfect square the
        trailing unused cells are None

    Raises
    ValueError: if ``layout`` is not one of the supported names.
    """
    nrows, ncols, naxes = _subplot_grid_shape(naxes, layout, nrows, ncols)
    fig = plt.figure()
    axes = [fig.add_subplot(nrows, ncols, i + 1) for i in range(naxes)]
    # Pad so the list can be reshaped to the full grid.
    axes += [None] * (nrows * ncols - naxes)
    axes = numpy.reshape(axes, (nrows, ncols))
    return fig, axes
def get_group_by_data2(data=None, xvar=None, cvar=None, exclude=None):
    """
    Split the values of one column by the categories of another column.

    Single-variable counterpart of ``get_group_by_data`` (used for colour-
    grouped histograms).

    Parameters
    data: dataframe
    xvar: name of the column whose values are split
    cvar: name of the categorical column to group by (required)
    exclude: iterable of ``cvar`` values to leave out

    Returns
    x-values: list of Series, one per remaining ``cvar`` value (sorted)
    legends: the matching ``cvar`` values
    Both lists are empty when no category remains.

    Raises
    ValueError: if ``cvar`` is not given.
    """
    if not cvar:
        raise ValueError("'cvar' must name the column to group by")
    cvar_vals = sorted(data[cvar].unique())
    if exclude:
        # Honour ``exclude`` the same way get_group_by_data does.
        excluded = list(exclude)
        cvar_vals = [val for val in cvar_vals if val not in excluded]
    xlist = [data[data[cvar] == cval][xvar] for cval in cvar_vals]
    return xlist, list(cvar_vals)
def plot_hist(data=None, xvar=None, bins=15, cvar=None, ax=None, labels=True, ticks=True):
    """
    Plot a histogram of one dataframe column, optionally split by a category.

    Parameters
    data: dataframe
    xvar: column to histogram
    bins: number of bins, or explicit bin edges
    cvar: optional categorical column; one semi-transparent histogram is
        drawn per category, all on the same bin edges so they are comparable
    ax: axis to draw on; if not provided, a new 7x5-inch figure is created
    labels: if True, label the axes (and add a legend when ``cvar`` is set)
    ticks: if False, hide the tick labels

    Returns
    ax when one was passed in; otherwise None.
    """
    ax_passed = ax is not None
    # create figure and ax if axes not already passed
    if not ax_passed:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        fig.set_size_inches(7, 5)
    if cvar:
        # color by a variable
        xlist, legends = get_group_by_data2(data=data, xvar=xvar, cvar=cvar)
        color_list = get_color_list(num=len(xlist))
        if color_list is None:
            # No palette of this size: fall back to matplotlib's colour cycle.
            color_list = [None] * len(xlist)
        # Bin edges from the whole column, so every group uses the same bins
        # (otherwise each group would get its own, incomparable bin widths).
        edges = numpy.histogram(data[xvar].dropna(), bins=bins)[1]
        for x, legend, color in zip(xlist, legends, color_list):
            # Semi-transparent so overlapping groups stay visible.
            ax.hist(x, bins=edges, color=color, alpha=0.5, label=str(legend))
    else:
        ax.hist(data[xvar], bins=bins, color=get_color_list(1)[0])
    if labels:
        ax.set_xlabel(xvar)
        ax.set_ylabel('frequency')
        if cvar:
            ax.legend(title=cvar)
    if not ticks:
        ax.xaxis.set_major_formatter(plt.NullFormatter())
        ax.yaxis.set_major_formatter(plt.NullFormatter())
    if ax_passed:
        return ax
def plot_pairs(data=None, vars=None, cvar=None):
    """
    Plot bivariate relationships across the columns in ``vars``.

    The diagonal cell for each variable holds its histogram (``plot_hist``).
    The off-diagonal cell in row i, column j plots ``vars[j]`` (x) against
    ``vars[i]`` (y) with ``plot_scatter``. Only the left column shows y
    labels and only the bottom row shows x labels.

    Parameters
    data: dataframe
    vars: list of columns to plot
    cvar: optional column used to group/colour the data

    Returns
    fig: the matplotlib figure holding the grid

    Not well tested.
    TODO: draw grid lines on the diagonal plots that line up with the
    off-diagonal plots, with the corresponding y tick labels.
    """
    n = len(vars)
    fig, axes = get_subplots(naxes=n * n, layout='matrix')
    for i, yvar in enumerate(vars):
        for j, xvar in enumerate(vars):
            ax = axes[i, j]
            ax.ticklabel_format(axis='both', style='sci', scilimits=(-3, 3))
            if i == j:
                plot_hist(data=data, xvar=yvar, cvar=cvar, ax=ax, labels=False, ticks=True)
            else:
                plot_scatter(data=data, xvar=xvar, yvar=yvar, cvar=cvar, ax=ax, labels=False, ticks=True)
            # x, y labels and ticks only on the outer edge of the grid
            if j == 0:
                ax.set_ylabel(yvar)
            else:
                ax.yaxis.set_major_formatter(plt.NullFormatter())
            if i == n - 1:
                ax.set_xlabel(xvar)
            else:
                ax.xaxis.set_major_formatter(plt.NullFormatter())
    for i in range(n):
        if n > 1:
            # Align each diagonal histogram's x-range with a scatter plot in
            # the same column (same x variable); there is none when n == 1.
            same_column = axes[1, 0] if i == 0 else axes[i - 1, i]
            axes[i, i].set_xlim(same_column.get_xlim())
        # Positional flag: the ``b=`` keyword was removed in Matplotlib 3.7.
        axes[i, i].grid(False)
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig
def plot_pairs_sns(data=None, vars=None, cvar=None):
    """
    Plot pair-wise relationships using seaborn's ``PairGrid``.

    Histograms go on the diagonal and scatter plots above and below it,
    coloured by ``cvar`` with the palette from ``get_color_list``.
    Note: this changes the global seaborn palette via ``sns.set_palette``.

    Parameters
    data: dataframe
    vars: list of columns to plot (passed to ``PairGrid`` as ``vars``)
    cvar: optional column to colour by; a legend is added only when given

    Returns
    the seaborn ``PairGrid`` object
    """
    # Count only non-missing categories; without cvar a single colour is used
    # (previously data[None] raised KeyError).
    ncolors = data[cvar].nunique() if cvar else 1
    sns.set_palette(get_color_list(num=ncolors))
    grid = sns.PairGrid(data, vars=vars, hue=cvar)
    grid = grid.map_diag(plt.hist)
    grid = grid.map_lower(plt.scatter)
    grid = grid.map_upper(plt.scatter)
    if cvar:
        grid = grid.add_legend()
    return grid
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment