# Last active March 12, 2023 20:48
# Density scatter plot with marginals
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
import matplotlib.ticker as plticker
# Module-level random generator with a fixed seed (42), so repeated runs draw
# the same samples.
rng = np.random.default_rng(42)
def generate_dist(mu=(0, 0), sigma1=1, sigma2=2, rho=0.5, n_points=250):
    """
    Generate sample from bivariate normal dist

    Draws from the module-level generator ``rng``.

    :param mu: means of the two coordinates, ``(mean_x, mean_y)``
    :param sigma1: standard deviation of the x coordinate
    :param sigma2: standard deviation of the y coordinate
    :param rho: correlation coefficient between x and y
    :param n_points: number of points to draw
    :return: x, y arrays of coordinates, ``n_points`` values each
    """
    covariance = rho * sigma1 * sigma2
    cov = [[sigma1 ** 2, covariance],
           [covariance, sigma2 ** 2]]
    # The sample has one row per point; transposing turns the two coordinate
    # columns into two rows that can be unpacked directly.
    x, y = rng.multivariate_normal(mu, cov, n_points).T
    return x, y
def kde(x, y, xmin, xmax, ymin, ymax, lim=1e-3, n_grid=500):
    """
    Compute KDE on the plot space

    :param x: x coordinates of the data points
    :param y: y coordinates of the data points
    :param xmin: lower bound of the evaluation grid along x
    :param xmax: upper bound of the evaluation grid along x
    :param ymin: lower bound of the evaluation grid along y
    :param ymax: upper bound of the evaluation grid along y
    :param lim: minimal kde value to present, anything lower goes to 0
    :param n_grid: number of grid nodes along each axis
    :return: x grid, y grid, kde evaluated at each point
    """
    # An imaginary "step" tells mgrid to produce that many points, with both
    # end points included, instead of using it as a spacing.
    step = complex(0, n_grid)
    xx, yy = np.mgrid[xmin:xmax:step, ymin:ymax:step]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    kernel = st.gaussian_kde(np.vstack([x, y]))
    f = kernel(positions).reshape(xx.shape)
    # Zero out the negligible tail so it is not presented.
    f[f < lim] = 0
    return xx, yy, f
def boxplot_annot(x1, x2, y1, y2, text, ax, d=0.1, vert=True):
    """
    Put annotation on boxplot

    Draws a bracket connecting two boxes and writes ``text`` next to it.

    :param x1: position of the first box on the category axis
    :param x2: position of the second box on the category axis
    :param y1: value the bracket starts from above the first box
    :param y2: value the bracket starts from above the second box
    :param text: annotation text
    :param ax: axes to draw on
    :param d: offset unit used for the bracket legs and the text gap
    :param vert: True for vertical boxes, False for horizontal ones
    """
    peak = np.max([y1, y2])
    # Value-axis coordinates of the bracket: leg above box 1, the bar spanning
    # both boxes at peak + 2d, leg above box 2.
    along_values = [y1 + d, peak + d * 2, peak + d * 2, y2 + d]
    along_categories = [x1, x1, x2, x2]
    middle = (x1 + x2) / 2

    if vert:
        ax.plot(along_categories, along_values, c='k', lw=1)
        ax.text(middle, peak + d * 4, text, ha='center', va='center')
    else:
        # Same bracket with the roles of the two axes swapped.
        ax.plot(along_values, along_categories, c='k', lw=1)
        ax.text(peak + d * 3, middle, text, ha='left', va='center')
def boxplot(dat, xmin, xmax, annot=None, colours=None, ax=None, vert=True):
    """
    Display boxplots of marginal data

    :param dat: sequence of samples, one box is drawn per sample
    :param xmin: lower end of the range the boxes are spread over
    :param xmax: upper end of the range the boxes are spread over
    :param annot: optional ``(i, j, text)``; a bracket labelled ``text`` is
        drawn between box ``i`` and box ``j``
    :param colours: optional face colour per box; when None the colours are
        taken from the axes colour cycle
    :param ax: axes to draw on, defaults to the current axes
    :param vert: True for vertical boxes, False for horizontal ones
    """
    if ax is None:
        ax = plt.gca()

    # The range is split into len(dat) equal slots; each box sits in the
    # middle of its slot and takes 60% of the slot width.
    pos = (np.arange(len(dat)) + 0.5) * (xmax - xmin) / len(dat)
    box_width = (xmax - xmin) / len(dat) * 0.6
    bplots = [
        # patch_artist=True makes the boxes filled patches so they can be
        # given a face colour below.
        ax.boxplot(sample, positions=[centre], widths=box_width,
                   vert=vert, patch_artist=True)
        for sample, centre in zip(dat, pos)
    ]

    # it's madness, but i don't know the way to get the state of colour cycler
    # without changing it: draw an empty scatter per box and read back the
    # face colour it was assigned, forcing alpha to 1.
    if colours is None:
        colours = []
        for _ in dat:
            sc = ax.scatter([], [])
            colours.append(list(sc.get_facecolor()[0][:-1]) + [1])

    for bp, colour in zip(bplots, colours):
        for box in bp['boxes']:
            box.set_facecolor(colour)

    if annot is not None:
        xi1, xi2, val = annot
        boxplot_annot(pos[xi1], pos[xi2], np.max(dat[xi1]), np.max(dat[xi2]),
                      val, ax=ax, d=0.05 * ax.get_ylim()[1], vert=vert)
def mainplot(dat, xmin, xmax, ymin, ymax, colours=None, ax=None):
    """
    Display 2 dimensional distribution of data points with KDE contours

    :param dat: sequence of ``(x, y)`` pairs, one per group of points
    :param xmin: lower bound of the KDE grid along x
    :param xmax: upper bound of the KDE grid along x
    :param ymin: lower bound of the KDE grid along y
    :param ymax: upper bound of the KDE grid along y
    :param colours: optional colour per group; when None matplotlib picks
        the colours
    :param ax: axes to draw on, defaults to the current axes
    """
    if ax is None:
        ax = plt.gca()

    for i, d in enumerate(dat):
        c = colours[i] if colours is not None else None
        sc = ax.scatter(*d, s=15, alpha=0.5, c=c)
        # Reuse the colour the scatter actually got, but fully opaque, so the
        # contour lines match their points.
        cl = [list(sc.get_facecolors()[0][:-1]) + [1]]
        xx, yy, f = kde(*d, xmin, xmax, ymin, ymax)
        ax.contour(xx, yy, f, 5, colors=cl)
def plot_legend(dat, labels, colours=None, ax=None):
    """
    Construct and display legends

    :param dat: sequence of plotted groups; only its length is used
    :param labels: legend label per group
    :param colours: optional colour per group; when None matplotlib picks
        the colours
    :param ax: axes to put the legend on, defaults to the current axes
    """
    if ax is None:
        ax = plt.gca()

    handles = []
    for i, _ in enumerate(dat):
        c = colours[i] if colours is not None else None
        # An empty line draws nothing but provides a handle for the legend.
        element, = ax.plot([], [], '-o', c=c, label=labels[i])
        handles.append(element)
    ax.legend(handles=handles, loc='center', frameon=False)
def set_style(axs, loc_base=10):
    """
    Standard style set, removes part of boxes, set locator to a standard base

    :param axs: 2x2 array of axes; ``[0, 0]`` is styled without bottom/top
        ticks, ``[0, 1]`` gets the tick locators, ``[1, 1]`` is styled
        without left/right ticks and ``[1, 0]`` is switched off
    :param loc_base: spacing between major ticks on ``axs[0, 1]``
    """
    left, main, bottom, corner = axs[0, 0], axs[0, 1], axs[1, 1], axs[1, 0]

    for panel in (left, main, bottom):
        panel.spines.right.set_visible(False)
        panel.spines.top.set_visible(False)

    left.tick_params(bottom=False, top=False,
                     labelbottom=False, labeltop=False)

    # Each axis gets its own locator: a locator keeps a reference to the axis
    # it is attached to, so one instance must not serve two axes.
    main.xaxis.set_major_locator(plticker.MultipleLocator(base=loc_base))
    main.yaxis.set_major_locator(plticker.MultipleLocator(base=loc_base))

    bottom.tick_params(left=False, right=False,
                       labelleft=False, labelright=False)

    corner.axis('off')
# generate data
x1, y1 = generate_dist(mu=[16, 28], sigma1=2, sigma2=2.5, rho=0.7, n_points=30)
x2, y2 = generate_dist(mu=[24, 40], sigma1=3, sigma2=5, rho=0.5, n_points=200)
x3, y3 = generate_dist(mu=[14, 20], sigma1=3, sigma2=3, rho=0.5, n_points=50)
# the first plotted group is the first two samples pooled together
x1 = np.concatenate([x1, x2])
y1 = np.concatenate([y1, y2])
# define plot range
xmin, xmax = 1, 65
ymin, ymax = 1, 65
colours = ['cornflowerblue', 'darkorange']  # set to None to use the default colours
data_list = [(x1, y1), (x3, y3)]
# plot the plot
# narrow left column and short bottom row hold the marginal boxplots
fig, axs = plt.subplots(2, 2, figsize=(6, 4), dpi=150,
                        sharex=True, sharey=True,
                        gridspec_kw={'width_ratios': [1, 5],
                                     'height_ratios': [5, 1]})
# set axis style
set_style(axs)
# main plot
mainplot(data_list, xmin, xmax, ymin, ymax, ax=axs[0, 1], colours=colours)
# box plot
boxplot([y1, y3],
        ymin, ymax, ax=axs[0, 0], annot=[0, 1, r'$P < 2\times10^{-16}$'],
        colours=colours, vert=True)
boxplot([x1, x3],
        xmin, xmax, ax=axs[1, 1], annot=[0, 1, r'$P < 2\times10^{-16}$'],
        colours=colours, vert=False)
# plot legend
plot_legend(data_list, labels=['Dist 1', 'Dist 2'], ax=axs[1, 0], colours=colours)
# general labels
# fig.suptitle('Plot implemented in pure matplotlib')
axs[1, 1].set_xlabel('Values on x-axis')
axs[0, 0].set_ylabel('Values on y-axis')
# the axes were created with sharex/sharey, so limits set here are shared
plt.xlim(xmin, xmax)
plt.ylim(ymin, ymax)
