Skip to content

Instantly share code, notes, and snippets.

@philastrophist
Last active April 26, 2022 13:07
Show Gist options
  • Save philastrophist/b0e835f7cb71d7b0319c0d7c30a3fe30 to your computer and use it in GitHub Desktop.
Save philastrophist/b0e835f7cb71d7b0319c0d7c30a3fe30 to your computer and use it in GitHub Desktop.
HPD estimation
# -*- coding: utf-8 -*-
from __future__ import division
import numpy as np
from fastkde import fastKDE
from warnings import warn
__all__ = ['find_mode']
def make_indices(dimensions):
    """Generate the complete set of indices for an array of the given dimensions.

    :Arguments:
      dimensions : sequence of int
          Length of each axis.

    :Returns:
      For a single axis, a list of plain integers ``[0, ..., n - 1]``.
      Otherwise a list of index tuples in row-major order (the last axis
      varies fastest). An empty ``dimensions`` yields ``[()]``.
    """
    # Local import: the cartesian product is all this function needs.
    from itertools import product

    if len(dimensions) == 1:
        # One axis: bare integers rather than 1-tuples.
        return list(range(dimensions[0]))
    return list(product(*(range(size) for size in dimensions)))
def calc_min_interval(x, cred_mass):
    """Internal method to determine the minimum-width interval that spans
    a given fraction of the samples.

    Assumes that x is a sorted numpy array.

    credit: pymc3

    :Arguments:
      x : Numpy array
          Sorted samples.
      cred_mass : float
          Fraction of the samples the interval must span, between 0 and 1
          inclusive.

    :Returns:
      ``(hdi_min, hdi_max)``, the end points of the narrowest window.

    :Raises:
      ValueError
          If ``x`` is empty or ``cred_mass`` lies outside ``[0, 1]``.
    """
    n = len(x)
    if n == 0:
        raise ValueError('Too few elements for interval calculation')
    if not 0 <= cred_mass <= 1:
        raise ValueError('cred_mass must lie between 0 and 1')
    # Clamp the offset to n - 1 so that cred_mass == 1 still leaves one
    # candidate window (the full range) instead of an empty width array.
    interval_idx_inc = min(int(np.floor(cred_mass * n)), n - 1)
    n_intervals = n - interval_idx_inc
    # Width of every window that starts at x[k] and ends interval_idx_inc later.
    interval_width = x[interval_idx_inc:] - x[:n_intervals]
    min_idx = np.argmin(interval_width)
    hdi_min = x[min_idx]
    hdi_max = x[min_idx + interval_idx_inc]
    return hdi_min, hdi_max
def calc_hpd(x, cred_mass=0.68, transform=lambda x: x):
    """Compute the highest posterior density (HPD) interval of a set of samples.

    The HPD is the narrowest Bayesian credible interval holding the
    requested probability mass.

    :Arguments:
      x : Numpy array
          An array containing MCMC samples; the first axis indexes the samples.
      cred_mass : float
          Desired credible interval probability mass
      transform : callable
          Function applied to a copy of the data before the intervals are
          computed (defaults to identity)

    :Returns:
      Numpy array whose last axis holds ``(lower, upper)``. For a
      one-dimensional input its shape is ``(2,)``; otherwise it is the
      shape of the trailing axes of ``x`` followed by ``2``.
    """
    # Work on a transformed copy so the caller's array is left alone.
    samples = transform(x.copy())

    # Univariate node: one sort, one interval.
    if samples.ndim <= 1:
        return np.array(calc_min_interval(np.sort(samples), cred_mass))

    # Multivariate node: move the sample axis to the end so that indexing
    # with a parameter index yields that parameter's full chain.
    axis_order = list(range(1, samples.ndim)) + [0]
    by_param = np.transpose(samples, axis_order)
    param_shape = np.shape(by_param)[:-1]

    # One (lower, upper) pair per parameter.
    bounds = np.resize(0.0, param_shape + (2,))
    for index in make_indices(param_shape):
        # make_indices gives ints for a single axis and tuples otherwise;
        # both index the arrays directly.
        bounds[index] = calc_min_interval(np.sort(by_param[index]), cred_mass)
    return np.array(bounds)
def find_mode(trace, credible_interval_mass=0.68, restrict=False, **fastkde_kwargs):
    """
    Returns the estimated mode of your mcmc sample assuming it is generated from a continuous distribution,
    along with the Highest Posterior Density credible interval containing `credible_interval_mass` of the
    total probability. Each (flattened) parameter is processed independently.

    HPD is calculated using simple sorted arrays and the mode is taken at the highest value of a
    KDE (kernel density estimation) curve.

    trace: nd trace or chain from analysis; the first axis is the sample/step axis, e.g. (nsamples, ndims)
    credible_interval_mass: The probability mass that the credible interval contains (0.68 == 1sigma)
    restrict: If true, estimate the mode only using samples strictly inside the credible interval.
        Slight speed bonus with reduced accuracy.
        The accuracy of the mode should not matter if its credible interval is larger than its accuracy.
    numPointsPerSigma: forwarded to fastKDE with the other keyword arguments; defaults to 30 [optional]

    returns: mode, hpd_interval, (kde_axes, kde_pdf)
        mode has the shape of the trailing axes of `trace`;
        hpd_interval has shape (2,) + that shape, holding [lowerbound, upperbound];
        (kde_axes, kde_pdf) come from the last parameter for which a KDE was evaluated, and are
        (None, None) when no KDE was needed (every parameter had a zero-width HPD).

    credit: https://stats.stackexchange.com/questions/259319/reliability-of-mode-from-an-mcmc-sample
    """
    original_shape = trace.shape[1:]
    if trace.ndim == 1:
        trace = trace.reshape(-1, 1)
    else:
        trace = trace.reshape(trace.shape[0], -1)  # ravel parameter axes (first axis is steps)
    fastkde_kwargs.setdefault('numPointsPerSigma', 30)
    hpd_estimates = np.atleast_2d(calc_hpd(trace, credible_interval_mass))  # (dims, interval)
    modes = np.zeros(trace.shape[-1])
    # Stays (None, None) unless a KDE is actually evaluated. The original read
    # the loop-local `kde` at return time and crashed when it was never bound.
    kde_axes, kde_pdf = None, None
    for i, (param, hpd) in enumerate(zip(trace.T, hpd_estimates)):
        if hpd[0] == hpd[1]:
            # Zero-width interval: the mode is that single value, no KDE needed.
            modes[i] = hpd[0]
            continue
        if restrict:
            x = param[(param < hpd[1]) & (param > hpd[0])]
        else:
            x = param
        kde = fastKDE.fastKDE(x, **fastkde_kwargs)
        modes[i] = kde.axes[0][np.argmax(kde.pdf)]
        kde_axes, kde_pdf = kde.axes, kde.pdf
    modes = modes.reshape(original_shape)
    hpd_estimates = hpd_estimates.T.reshape((2,) + original_shape)
    if not np.all((modes <= hpd_estimates[1]) & (modes >= hpd_estimates[0])):
        warn("Mode estimation has resulted in a mode outside of the HPD region.\n"
             "HPD and mode are not reliable!")
    return modes, hpd_estimates, (kde_axes, kde_pdf)
def add_hpd_lines_corner_plot(corner_axes, param_label_list, mode, hpd, param_name):
    """Draw the mode and HPD bounds on one diagonal panel and title it.

    :Arguments:
      corner_axes : 2-d array of axes, indexed as ``corner_axes[i, i]``
      param_label_list : list
          Parameter labels; the position of ``param_name`` selects the panel.
      mode : float
          Mode of the parameter (solid vertical line).
      hpd : pair of floats
          ``(lower, upper)`` HPD bounds (dashed vertical lines).
      param_name : str
          Label of the parameter to annotate.

    The title reads ``name = mode^{+upper error}_{-lower error}``.
    """
    i = param_label_list.index(param_name)
    ax = corner_axes[i, i]
    lower, upper = hpd[0], hpd[1]
    ax.axvline(mode, color='k', linestyle='-')
    ax.axvline(lower, color='k', linestyle='--')
    ax.axvline(upper, color='k', linestyle='--')
    # The "+" superscript is the distance up to the upper bound and the "-"
    # subscript the distance down to the lower bound (these were swapped).
    ax.set_title('{}$ = {:.2f}^{{+{:.2f}}}_{{-{:.2f}}}$'.format(param_name, mode, upper - mode, mode - lower))
def corner_plot_hpd(data, labels, cred_mass=0.68, corner_kwargs=None, fastkde_kwargs=None):
    """Make a corner plot and annotate each diagonal panel with mode and HPD.

    :Arguments:
      data : Numpy array
          Samples, passed to both ``corner`` and ``find_mode``.
      labels : list of str
          One label per parameter; also fixes the size of the axes grid.
      cred_mass : float
          Credible interval probability mass passed to ``find_mode``.
      corner_kwargs : dict, optional
          Extra keyword arguments for ``corner``.
      fastkde_kwargs : dict, optional
          Extra keyword arguments for ``find_mode``.

    :Returns:
      ``(figure, modes, hpds)`` with ``modes`` and ``hpds`` exactly as
      returned by ``find_mode``.
    """
    # Imported here because the module only imports `corner` inside its
    # __main__ guard; without this, calling the function from an importing
    # module fails with a NameError.
    from corner import corner

    if corner_kwargs is None:
        corner_kwargs = {}
    if fastkde_kwargs is None:
        fastkde_kwargs = {}
    figure = corner(data, labels=labels, **corner_kwargs)
    axes = np.asarray(figure.axes).reshape(len(labels), len(labels))
    modes, hpds, _ = find_mode(data, cred_mass, **fastkde_kwargs)
    # Flatten to one entry per parameter so a 1-d `data` (0-d modes, hpds of
    # shape (2,)) can be iterated like the multi-parameter case.
    per_param_modes = np.asarray(modes).reshape(-1)
    per_param_hpds = np.asarray(hpds).reshape(2, -1).T
    for param, mode, hpd in zip(labels, per_param_modes, per_param_hpds):
        add_hpd_lines_corner_plot(axes, labels, mode, hpd, param)
    return figure, modes, hpds
if __name__ == '__main__':
    from corner import corner
    import matplotlib.pyplot as plt

    # (centre, spread) of the three Gaussian components stacked into every
    # simulated sample set
    components = [(0, 1), (3, 0.2), (2, 0.01)]

    # first a model with 3 dimensions (abc are just names of the dimensions),
    # then a model with 1 dimension
    for names in (list('abc'), ['a']):
        x = np.concatenate(
            [np.random.normal(centre, spread, size=(10000, len(names)))
             for centre, spread in components],
            axis=0)
        fig, mode, hpd = corner_plot_hpd(x, names)
        print(mode)
        print(hpd)

    # just get the peak and hpd without plotting (uses the 1-dimensional sample):
    mode, hpd, _ = find_mode(x, 0.68)
    # mode has shape (ndimensions,)
    # hpd has shape of ([lowerbound, upperbound], ndimensions)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment