-
-
Save mikepsn/a741ead5054ae995660817c27b40b765 to your computer and use it in GitHub Desktop.
A collection of univariate plots
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from textwrap import fill | |
from scipy.stats import norm, uniform, skewnorm, gaussian_kde, triang | |
from numpy import ( | |
array, linspace, quantile, histogram, atleast_2d, mean, std, add | |
) | |
from numpy.lib.stride_tricks import sliding_window_view | |
from matplotlib.pyplot import subplots, show, rc | |
from matplotlib.axes import Axes | |
import seaborn as sns | |
# Global matplotlib styling: 14pt text, and no axes spine on any side.
rc('font', size=14)
rc('axes.spines', **{side: False for side in ('top', 'right', 'left', 'bottom')})
# Frozen scipy distributions shown in the figure, one per column.  Parameters
# are passed as keywords (in this order) because the column titles are built
# from each frozen distribution's ``kwds``.
dists = [
    family(**params)
    for family, params in (
        (norm, {'loc': 10, 'scale': 2}),
        (uniform, {'loc': 0, 'scale': 20}),
        (skewnorm, {'a': 6, 'loc': 10, 'scale': 2}),
        (triang, {'c': 1, 'loc': 5, 'scale': 7}),
    )
]

# 200 draws per distribution; every distribution is seeded with 0 so the
# figure is reproducible from run to run.
samples = [frozen.rvs(random_state=0, size=200) for frozen in dists]
def tufte_quartiles(ax, data):
    """Draw a Tufte-style quartile plot of *data* on *ax* at y = 0.

    Two horizontal segments are drawn, from the minimum to the first
    quartile and from the third quartile to the maximum, leaving the
    interquartile range empty.  The median is marked with a single point.
    """
    lowest, lower_q, median, upper_q, highest = quantile(
        data, [0, .25, .5, .75, 1]
    )
    # One hlines call draws both whiskers: (min -> Q1) and (Q3 -> max).
    ax.hlines([0, 0], [lowest, upper_q], [lower_q, highest])
    ax.scatter([median], [0])
def color_density(ax, data):
    """Render the kernel density estimate of *data* as a colour strip on *ax*.

    The KDE is evaluated at 400 evenly spaced points between the smallest
    and largest observation, and drawn with the 'Blues' colormap as a band
    spanning y = 0 to y = 1.  *data* must provide ``.min()`` and ``.max()``.
    """
    lo, hi = data.min(), data.max()
    xs = linspace(lo, hi, 400)
    kde = gaussian_kde(data)
    # Duplicate the 1-D density into two identical rows so it can be drawn
    # as a 2-D mesh against the y coordinates [0, 1].
    strip = atleast_2d(kde(xs)).repeat(2, axis=0)
    ax.pcolormesh(xs, [0, 1], strip, cmap='Blues')
def point_decile(ax, data):
    """Draw a decile interval plot of *data* on *ax* at y = 0.

    Each of the ten intervals between consecutive deciles becomes one
    horizontal segment; segments are thickest next to the median and thin
    out toward the extremes.  The median is overlaid as a white point.
    """
    deciles = quantile(data, linspace(0, 1, 11))
    # Consecutive pairs (d0, d1), (d1, d2), ... delimit the ten segments.
    starts, stops = deciles[:-1], deciles[1:]
    widths = 4 * array([1, 2, 3, 4, 5, 5, 4, 3, 2, 1])
    ax.hlines([0] * len(starts), starts, stops, linewidths=widths)
    # deciles[5] is the 50% quantile; zorder keeps the dot above the bars.
    ax.scatter(deciles[5], 0, color='white', zorder=7)
def point_multi_sigmas(ax, data):
    """Draw nested 1/2/3-sigma bands around the mean of *data* on *ax*.

    Five horizontal segments at y = 0 join the points mean + k * sd for
    k = -3, -2, -1, 1, 2, 3; the central (-1 sd to +1 sd) segment is the
    thickest.  The mean itself is overlaid as a white point.
    """
    center, spread = mean(data), std(data)
    edges = center + spread * array([-3, -2, -1, 1, 2, 3])
    # Consecutive edges delimit the five segments.
    starts, stops = edges[:-1], edges[1:]
    widths = 5 * array([1, 2, 3, 2, 1])
    ax.hlines([0] * len(starts), starts, stops, linewidths=widths)
    # zorder keeps the mean marker above the bars.
    ax.scatter(center, 0, color='white', zorder=7)
# Registry of (row label, plotting callable) pairs.  Each entry produces one
# row of the figure, in the order listed here.
#
# The plotting loop further down calls seaborn callables as
# ``func(x=sample, ax=ax, **kwargs)`` and every other callable as
# ``func(ax, sample, *args, **kwargs)``; that is why ``Axes.eventplot`` is
# stored unbound and why the local helpers take ``(ax, data)``.
univariate_funcs = [
    # --- individual observations ---
    ('strip', partial(sns.stripplot, jitter=.3, ec='white', size=3)),
    ('swarm', partial(sns.swarmplot, size=3)),
    ('rug', partial(Axes.eventplot, alpha=.4)),
    # --- density / distribution estimates ---
    ('kernel density (area)', partial(sns.kdeplot, fill=True)),
    ('kernel density (color)', color_density),
    ('cumulative KDE', partial(sns.kdeplot, cumulative=True)),
    ('empirical CDF', sns.ecdfplot),
    ('histogram', partial(sns.histplot, bins='auto')),
    # --- summary-statistic plots ---
    ('Box', sns.boxplot),
    ('Boxen', sns.boxenplot),
    ('Tufte Quartile', tufte_quartiles),
    # Labels below are raw strings because they contain mathtext backslashes.
    (r'Point $\bar{x}\pm\sigma$', partial(sns.pointplot, orient='h', errorbar='sd')),
    ('Point Deciles', point_decile),
    (r'Point $\bar{x}\pm$ 3$\sigma$,2$\sigma$,1$\sigma$', point_multi_sigmas),
]
# Figure grid: one header row for the theoretical PDFs plus one row per plot
# type, and one column per distribution.  Rows share a y-axis and columns
# share an x-axis.
gridspec_kw = {
    'hspace': .1,
    'wspace': .02,
    'left': .15,
    'right': .9,
    'bottom': .05,
}
fig, axes = subplots(
    nrows=1 + len(univariate_funcs),
    ncols=len(dists),
    sharex='col',
    sharey='row',
    figsize=(16, 12),
    dpi=106,
    gridspec_kw=gridspec_kw,
)
# Header row: the theoretical PDF of each distribution, titled with the
# distribution name and the keyword parameters it was frozen with.
for ax, dist in zip(axes[0], dists):
    # Plot between the 0.1% and 99.9% quantiles of the distribution.
    lo, hi = dist.ppf([.001, .999])
    xs = linspace(lo, hi, 400)
    density = dist.pdf(xs)
    ax.plot(xs, density)
    ax.fill_between(xs, density, alpha=.4)

    family = dist.dist.name.title()
    params = ', '.join(f'{key}={value}' for key, value in dist.kwds.items())
    ax.set_title(f"{family}\n{params}")
# Fill every row below the header: one plot type per row, one sample per
# column.
for i, (name, func) in enumerate(univariate_funcs, start=1):
    # Unwrap partials so the underlying callable's module can be inspected
    # and its stored arguments forwarded explicitly.
    if isinstance(func, partial):
        plot, args, kwargs = func.func, func.args, func.keywords
    else:
        plot, args, kwargs = func, (), {}

    # The top-level package decides the calling convention.  It does not
    # depend on the sample, so it is resolved once per row.
    package, _, _ = plot.__module__.partition('.')
    is_seaborn = package == 'seaborn'

    # Row label: capitalise each word, leave all-uppercase tokens (e.g. KDE,
    # CDF) untouched, then wrap to at most 20 characters per line.
    label = ' '.join(
        word if word.isupper() else word.capitalize() for word in name.split()
    )
    label = fill(label, width=20, break_long_words=False)

    for j, s in enumerate(samples):
        ax = axes[i, j]
        if is_seaborn:
            # seaborn functions take the data and target axes as keywords;
            # positional arguments stored on a partial are forwarded too.
            plot(*args, x=s, ax=ax, **kwargs)
        else:
            # Everything else (Axes methods, local helpers) is (ax, data, ...).
            plot(ax, s, *args, **kwargs)
        if j == 0:
            # Only the first column carries the row label.  It is set after
            # plotting so it replaces any y-label the plotting call applied.
            ax.set_ylabel(label, rotation=0, ha='right', va='center')
# Strip tick marks everywhere: no y ticks or y tick labels on any axes, and
# x ticks / x tick labels only on the bottom row.
invisible_ticks = {'length': 0, 'width': 0}
for ax in axes.ravel():
    ax.yaxis.set_tick_params(labelleft=False, **invisible_ticks)
for ax in axes[:-1].ravel():
    ax.xaxis.set_tick_params(labelbottom=False, **invisible_ticks)
from matplotlib.lines import Line2D

# Horizontal rule separating the header row from the plot rows, placed
# halfway between the bottom of the header axes and the top of the first
# plot row.
header_bbox = axes[0, 0].get_position()
row_bbox = axes[1, 0].get_position()
sep_y = row_bbox.y1 + (header_bbox.y0 - row_bbox.y1) / 2
sepline = Line2D([.1, .9], [sep_y, sep_y], color='k')
fig.add_artist(sepline)
# Figure title, centred over the axes grid (between the gridspec's left and
# right margins) rather than over the whole canvas.
gs = fig.axes[0].get_gridspec()
centered = gs.left + (gs.right - gs.left) / 2
fig.text(
    centered, .98, 'A Collection of Univariate Plots',
    ha='center', va='top', fontsize='xx-large',
)

# The figure is only written to disk; call show() here for an interactive
# preview instead.
fig.savefig('univariateplots.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment