Skip to content

Instantly share code, notes, and snippets.

@philastrophist
Created December 14, 2016 21:11
Show Gist options
  • Save philastrophist/2c4252acbe12084e5b7e00a927f45045 to your computer and use it in GitHub Desktop.
Save philastrophist/2c4252acbe12084e5b7e00a927f45045 to your computer and use it in GitHub Desktop.
from __future__ import division
import pymc3 as pm
import numpy as np
from theano import tensor as T
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import seaborn as sns
"""
Extended from: https://gist.github.com/AustinRochford/fa24221f09df20071c06
Support for per-measurement uncertainties is not implemented yet (planned).
"""
def test_correlation(values, uncertainties=None, rhoxys=0, nsteps=20000, nburnin=2000, nthin=2):
    """Sample the posterior of a multivariate normal with an LKJ correlation prior.

    Parameters
    ----------
    values : array_like, shape (nmeas, ndim)
        Observed data, one row per measurement.
    uncertainties : array_like, optional
        Per-measurement uncertainties. Not supported yet; anything other than
        ``None`` raises ``NotImplementedError``.
    rhoxys : float, optional
        Currently unused; kept for backward compatibility.
    nsteps : int
        Number of Metropolis samples to draw.
    nburnin : int
        Number of leading samples discarded as burn-in.
    nthin : int
        Keep every ``nthin``-th sample after burn-in.

    Returns
    -------
    The sliced PyMC3 trace. Its variables are ``sigma``, ``nu``,
    ``corr_coeffs``, ``corr``, ``sigma_mat``, ``cov``, ``tau`` and ``mu``.

    Raises
    ------
    NotImplementedError
        If ``uncertainties`` is given.
    """
    # Fail fast, before building the model.
    if uncertainties is not None:
        raise NotImplementedError("Soon...")
    values = np.asarray(values)
    _, ndim = values.shape

    # Number of free (strictly upper-triangular) correlation coefficients.
    n_elem = ndim * (ndim - 1) // 2
    # Map every (i, j) matrix position to its index in the flat LKJCorr
    # vector. Both triangles share indices (symmetric matrix); the diagonal
    # points at element 0 and is overwritten with 1 further down.
    tri_index = np.zeros([ndim, ndim], dtype=int)
    upper = np.triu_indices(ndim, k=1)
    tri_index[upper] = np.arange(n_elem)
    tri_index[upper[::-1]] = np.arange(n_elem)

    with pm.Model():
        sigma = pm.Lognormal('sigma', np.zeros(ndim), np.ones(ndim), shape=ndim)
        nu = pm.Uniform('nu', 0, 5)
        corr_coeffs = pm.LKJCorr('corr_coeffs', nu, ndim)
        corr = pm.Deterministic('corr', T.fill_diagonal(corr_coeffs[tri_index], 1))
        sigma_diag = pm.Deterministic('sigma_mat', T.nlinalg.diag(sigma))
        # cov = diag(sigma) . corr . diag(sigma)
        cov = pm.Deterministic('cov', T.nlinalg.matrix_dot(sigma_diag, corr, sigma_diag))
        tau = pm.Deterministic('tau', T.nlinalg.matrix_inverse(cov))
        # Pass tau by keyword: in later PyMC3 releases the second positional
        # argument of MvNormal is the covariance, not the precision.
        mu = pm.MvNormal('mu', 0, tau=tau, shape=ndim)
        # Likelihood. Without it the data never enter the model and the
        # sampler merely draws from the prior.
        pm.MvNormal('obs', mu, tau=tau, observed=values)
        step = pm.Metropolis()
        return pm.sample(nsteps, step=step)[nburnin::nthin]
def plot_correlation_2d(mu, covariance, correlation, data=None, ax=None, color='gray', dims=(0, 1), dim_names=None, text=True):
    """Draw the 95% confidence ellipse of one 2-D marginal of a Gaussian.

    Parameters
    ----------
    mu : array_like, shape (ndim,)
        Mean vector.
    covariance, correlation : array_like, shape (ndim, ndim)
        Covariance and correlation matrices.
    data : array_like, shape (nmeas, ndim), optional
        Points to scatter on top of the ellipse.
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when omitted.
    color : str
        Face colour of the ellipse and colour of the mean marker.
    dims : pair of int
        The two dimensions to plot (x axis, y axis).
    dim_names : pair of str, optional
        Axis labels; defaults to ``['x', 'y']``.
    text : bool
        If true, annotate the panel with the correlation coefficient.

    Returns
    -------
    The axes that were drawn on.
    """
    if dim_names is None:
        dim_names = ['x', 'y']
    idx = list(dims)
    mu = np.asarray(mu)[idx]
    # 2x2 sub-matrices for the selected pair of dimensions.
    covariance = np.asarray(covariance)[np.ix_(idx, idx)]
    correlation = np.asarray(correlation)[np.ix_(idx, idx)]

    # eigh is the solver for symmetric matrices: real eigenvalues and
    # orthonormal eigenvectors stored as the columns of U.
    var, U = np.linalg.eigh(covariance)
    # Round-off can give tiny negative eigenvalues. This matters for
    # singular matrices such as the diagonal panels, where dims[0] == dims[1].
    var = np.clip(var, 0, None)
    # The ellipse width lies along the eigenvector of var[0]. Use both of
    # its components so the sign of the correlation (the tilt) is kept.
    angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))

    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    # 5.991 is the 95% quantile of a chi-square distribution with 2 degrees
    # of freedom, so this is the 95% probability ellipse.
    w, h = 2 * np.sqrt(5.991 * var)
    e = Ellipse(xy=mu, width=w, height=h, angle=angle)
    e.set_alpha(0.5)
    e.set_facecolor(color)
    ax.add_patch(e)
    ax.plot(*mu, color=color, linestyle='none', marker='+', markersize=20, mew=4)
    if text:
        ax.text(1, 1, r'$\rho={:.2f}$'.format(correlation[0, 1]), transform=ax.transAxes, ha='right', va='top')
    if data is not None:
        data = np.asarray(data)[:, idx]
        ax.scatter(data[:, 0], data[:, 1], c='k', alpha=0.5)
    # Square window centred on the mean, large enough to contain the ellipse.
    dist = max(w, h)
    ax.set_ylim([mu[1] - dist, mu[1] + dist])
    ax.set_xlim([mu[0] - dist, mu[0] + dist])
    ax.set_xlabel(dim_names[0])
    ax.set_ylabel(dim_names[1])
    return ax
def plot_correlation(mu, covariance, correlation, data=None, color='gray', dim_names=None, fig=None, text=True):
    """Draw a grid of 2-D confidence ellipses, one panel per pair of dimensions.

    Parameters
    ----------
    mu : array_like, shape (n,)
        Mean vector.
    covariance, correlation : array_like, shape (n, n)
        Covariance and correlation matrices.
    data : array_like, shape (nmeas, n), optional
        Points scattered on every panel.
    color : str
        Colour passed to each panel.
    dim_names : sequence of str, optional
        Axis labels; defaults to ``dim_0 ... dim_{n-1}``.
    fig : matplotlib Figure, optional
        Existing figure whose axes are reused in order, for example to
        overplot a second estimate. It should hold ``n * n`` axes.
    text : bool
        Whether to annotate each panel with its correlation coefficient.

    Returns
    -------
    The figure that was drawn on.
    """
    mu, covariance, correlation = map(np.asarray, [mu, covariance, correlation])
    n = covariance.shape[0]
    if dim_names is None:
        dim_names = ['dim_{}'.format(i) for i in range(n)]
    if fig is None:
        # squeeze=False keeps a 2-D array of axes even when n == 1.
        fig, axes = plt.subplots(nrows=n, ncols=n, squeeze=False)
        axes = axes.ravel()
    else:
        axes = fig.axes
    # Row-major order matches the flattened axes grid.
    pairs = [(i, j) for i in range(n) for j in range(n)]
    for (i, j), ax in zip(pairs, axes):
        plot_correlation_2d(mu, covariance, correlation, data, ax, color, (i, j), (dim_names[i], dim_names[j]), text=text)
    # Lay out *this* figure; plt.tight_layout() would act on whichever figure
    # happens to be current (e.g. a traceplot created in between).
    fig.tight_layout()
    return fig
if __name__ == '__main__':
    # Demo: draw synthetic correlated data, fit it, and overplot the
    # posterior-median ellipses (blue) on the true ones (gray).
    from scipy.stats import multivariate_normal
    np.random.seed(0)
    ndim = 5
    mu = [1.] * ndim
    # True correlation matrix: rho(dim_0, dim_1) = 0.5,
    # rho(dim_1, dim_2) = -0.3, all other off-diagonal entries 0.
    corr = np.eye(ndim)
    corr[0, 1] = corr[1, 0] = 0.5
    corr[1, 2] = corr[2, 1] = -0.3
    stds = [1] * ndim  # standard deviation of every dimension
    D = np.diag(stds)
    cov = D.dot(corr).dot(D)
    data = multivariate_normal.rvs(mu, cov, size=37)  # generate test data
    fig = plot_correlation(mu, cov, corr, data, text=False)  # plot truth values
    trace = test_correlation(data, nsteps=20000, nburnin=10000, nthin=1)  # run!
    pm.traceplot(trace, varnames=['corr', 'mu'])
    # Posterior medians as point estimates.
    mu = np.median(trace['mu'], axis=0)
    cov = np.median(trace['cov'], axis=0)
    corr = np.median(trace['corr'], axis=0)
    plot_correlation(mu, cov, corr, color='blue', fig=fig)
    # print(...) with one argument is valid in both Python 2 and 3.
    print(pm.summary(trace, varnames=['corr_coeffs', 'mu', 'sigma']))
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment