Created
December 14, 2016 21:11
-
-
Save philastrophist/2c4252acbe12084e5b7e00a927f45045 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import pymc3 as pm | |
import numpy as np | |
from theano import tensor as T | |
import matplotlib.pyplot as plt | |
from matplotlib.patches import Ellipse | |
import seaborn as sns | |
""" | |
extended from: https://gist.github.com/AustinRochford/fa24221f09df20071c06 | |
uncertainties to come | |
""" | |
def test_correlation(values, uncertainties=None, rhoxys=0, nsteps=20000, nburnin=2000, nthin=2):
    """Sample the posterior of a multivariate-normal model fitted to ``values``.

    Priors: log-normal on each per-dimension standard deviation ``sigma``, an
    LKJ prior (with shape ``nu ~ Uniform(0, 5)``) on the correlation matrix,
    and a zero-mean multivariate normal on the mean vector ``mu``.  Each row
    of ``values`` is treated as one observed draw from ``MvNormal(mu, cov)``.

    Parameters
    ----------
    values : array_like, shape (nmeas, ndim)
        Observed data, one measurement per row.
    uncertainties : None
        Per-measurement uncertainties; not supported yet.
    rhoxys : ignored
        Unused; kept only for backward compatibility of the signature.
    nsteps : int
        Number of Metropolis draws.
    nburnin : int
        Number of leading draws discarded as burn-in.
    nthin : int
        Keep every ``nthin``-th draw after burn-in.

    Returns
    -------
    The PyMC3 trace sliced as ``trace[nburnin::nthin]``.  It holds ``sigma``,
    ``nu``, ``corr_coeffs``, ``corr``, ``sigma_mat``, ``cov``, ``tau`` and ``mu``.

    Raises
    ------
    NotImplementedError
        If ``uncertainties`` is given.
    """
    if uncertainties is not None:
        # Fail fast, before building (and compiling) the model.
        raise NotImplementedError("Soon...")
    values = np.asarray(values)
    ndim = values.shape[1]
    with pm.Model():
        sigma = pm.Lognormal('sigma', np.zeros(ndim), np.ones(ndim), shape=ndim)
        nu = pm.Uniform('nu', 0, 5)
        # One coefficient per strictly-upper-triangular entry of the correlation matrix.
        corr_coeffs = pm.LKJCorr('corr_coeffs', nu, ndim)
        n_elem = ndim * (ndim - 1) // 2
        # Point both (i, j) and its mirror (j, i) at the same coefficient so the
        # matrix is symmetric.  Diagonal entries point at coefficient 0 and are
        # overwritten with ones by fill_diagonal below.
        upper = np.triu_indices(ndim, k=1)
        tri_index = np.zeros((ndim, ndim), dtype=int)
        tri_index[upper] = np.arange(n_elem)
        tri_index[upper[::-1]] = np.arange(n_elem)
        corr = pm.Deterministic('corr', T.fill_diagonal(corr_coeffs[tri_index], 1))
        # cov = diag(sigma) . corr . diag(sigma); tau is its inverse (precision).
        sigma_diag = pm.Deterministic('sigma_mat', T.nlinalg.diag(sigma))
        cov = pm.Deterministic('cov', T.nlinalg.matrix_dot(sigma_diag, corr, sigma_diag))
        tau = pm.Deterministic('tau', T.nlinalg.matrix_inverse(cov))
        # tau is passed by keyword: later PyMC3 versions read the second
        # positional MvNormal argument as a covariance, not a precision.
        mu = pm.MvNormal('mu', 0, tau=tau, shape=ndim)
        # Likelihood.  Without it the data never enter the model and the
        # sampled "posterior" is just the prior.
        pm.MvNormal('obs', mu, tau=tau, observed=values)
        step = pm.Metropolis()
        trace = pm.sample(nsteps, step=step)
    return trace[nburnin::nthin]
def confidence_ellipse(covariance, level=0.95):
    """Return ``(width, height, angle)`` of the ``level`` confidence ellipse of a 2x2 covariance.

    ``width`` is the full length of the major axis and ``height`` that of the
    minor axis.  ``angle`` is the anticlockwise angle, in degrees within
    [0, 180), of the major axis from the +x axis.  These are the conventions
    of ``matplotlib.patches.Ellipse``.
    """
    covariance = np.asarray(covariance, dtype=float)
    # The matrix is symmetric, so eigh returns real eigenvalues in ascending
    # order with orthonormal eigenvectors in the columns.
    var, vecs = np.linalg.eigh(covariance)
    # A singular matrix (e.g. a dimension plotted against itself) can yield a
    # tiny negative eigenvalue from round-off; clip so sqrt stays finite.
    var = np.clip(var, 0., None)
    # Chi-squared quantile with 2 degrees of freedom.  Its CDF is
    # 1 - exp(-x/2), so x = -2 ln(1 - level), which is ~5.991 for level = 0.95.
    scale = -2. * np.log(1. - level)
    major = vecs[:, 1]
    # arctan2 keeps the signs of both components.  That sign is what
    # distinguishes a positive tilt (positive correlation) from a negative one.
    angle = np.degrees(np.arctan2(major[1], major[0])) % 180.
    width = 2. * np.sqrt(scale * var[1])
    height = 2. * np.sqrt(scale * var[0])
    return width, height, angle


def plot_correlation_2d(mu, covariance, correlation, data=None, ax=None, color='gray',
                        dims=(0, 1), dim_names=None, text=True, level=0.95):
    """Draw the ``level`` confidence ellipse for two dimensions of a multivariate normal.

    Parameters
    ----------
    mu : array_like, shape (ndim,)
        Mean vector.
    covariance, correlation : array_like, shape (ndim, ndim)
        Covariance and correlation matrices.
    data : array_like, shape (nmeas, ndim), optional
        Points to scatter on top of the ellipse.
    ax : matplotlib Axes, optional
        Axes to draw on.  A new 8x6-inch figure is created if omitted.
    color : matplotlib colour
        Colour of the ellipse and of the mean marker.
    dims : pair of int
        Indices of the dimensions drawn on the x and y axes respectively.
    dim_names : pair of str, optional
        Axis labels; defaults to ``['x', 'y']``.
    text : bool
        Annotate the top-right corner with the correlation coefficient.
    level : float
        Probability mass enclosed by the ellipse (default 0.95).

    Returns
    -------
    The Axes drawn on.
    """
    if dim_names is None:
        dim_names = ['x', 'y']
    dims = list(dims)
    mu = np.asarray(mu)[dims]
    # 2x2 sub-matrices in (x, y) = (dims[0], dims[1]) order.  Swapping dims
    # transposes them, and confidence_ellipse handles that on its own.
    covariance = np.asarray(covariance)[np.ix_(dims, dims)]
    correlation = np.asarray(correlation)[np.ix_(dims, dims)]
    w, h, angle = confidence_ellipse(covariance, level)
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))
    e = Ellipse(tuple(mu), w, h, angle=angle)
    e.set_alpha(0.5)
    e.set_facecolor(color)
    ax.add_patch(e)
    ax.plot(*mu, color=color, linestyle='none', marker='+', markersize=20, mew=4)
    if text:
        ax.text(1, 1, r'$\rho={:.2f}$'.format(correlation[0, 1]), transform=ax.transAxes, ha='right', va='top')
    if data is not None:
        data = np.asarray(data)[:, dims]
        ax.scatter(data[:, 0], data[:, 1], c='k', alpha=0.5)
    # Square window centred on the mean, large enough to contain the ellipse.
    dist = max(w, h)
    ax.set_ylim([mu[1] - dist, mu[1] + dist])
    ax.set_xlim([mu[0] - dist, mu[0] + dist])
    ax.set_xlabel(dim_names[0])
    ax.set_ylabel(dim_names[1])
    return ax
def plot_correlation(mu, covariance, correlation, data=None, color='gray', dim_names=None, fig=None, text=True):
    """Draw pairwise confidence ellipses for every pair of dimensions on an n x n grid.

    The panel in row ``i``, column ``j`` plots dimension ``i`` on the x axis
    against dimension ``j`` on the y axis, via ``plot_correlation_2d``.

    Parameters
    ----------
    mu : array_like, shape (n,)
    covariance, correlation : array_like, shape (n, n)
    data : array_like, shape (nmeas, n), optional
        Points scattered in every panel.
    color : matplotlib colour for ellipses and mean markers.
    dim_names : sequence of str, optional
        One label per dimension; defaults to ``dim_0 ... dim_{n-1}``.
    fig : matplotlib Figure, optional
        Figure to overlay onto.  Its ``fig.axes`` are used in order; this
        assumes it came from an earlier call with the same ``n``.
    text : bool
        Annotate each panel with its correlation coefficient.

    Returns
    -------
    The Figure drawn on.
    """
    mu, covariance, correlation = map(np.asarray, (mu, covariance, correlation))
    n = covariance.shape[0]
    if dim_names is None:
        dim_names = ['dim_{}'.format(i) for i in range(n)]
    if fig is None:
        # squeeze=False keeps a 2-D array of Axes even when n == 1.
        fig, axes = plt.subplots(nrows=n, ncols=n, squeeze=False)
        axes = axes.ravel()
    else:
        axes = fig.axes
    pairs = [(i, j) for i in range(n) for j in range(n)]
    for pair, ax in zip(pairs, axes):
        plot_correlation_2d(mu, covariance, correlation, data, ax, color, pair,
                            (dim_names[pair[0]], dim_names[pair[1]]), text=text)
    # Lay out *this* figure.  plt.tight_layout() would act on pyplot's current
    # figure, which differs from `fig` when one is passed in.
    fig.tight_layout()
    return fig
if __name__ == '__main__':
    # Demo: simulate correlated 5-D data, fit the model, and overlay the
    # posterior-median estimate (blue) on the true ellipses (gray).
    from scipy.stats import multivariate_normal
    np.random.seed(0)
    ndim = 5
    mu = [1.] * ndim
    # True correlation matrix: rho(dim_0, dim_1) = 0.5 and
    # rho(dim_1, dim_2) = -0.3; every other pair is uncorrelated.
    corr = np.eye(ndim)
    corr[0, 1] = corr[1, 0] = 0.5
    corr[1, 2] = corr[2, 1] = -0.3
    stds = [1] * ndim  # true standard deviation of every dimension
    D = np.diag(stds)
    cov = D.dot(corr).dot(D)
    data = multivariate_normal.rvs(mu, cov, size=37)  # generate test data
    fig = plot_correlation(mu, cov, corr, data, text=False)  # plot the true values
    trace = test_correlation(data, nsteps=20000, nburnin=10000, nthin=1)  # run!
    pm.traceplot(trace, varnames=['corr', 'mu'])
    # Element-wise posterior medians used as point estimates.
    mu = np.median(trace['mu'], axis=0)
    cov = np.median(trace['cov'], axis=0)
    corr = np.median(trace['corr'], axis=0)
    plot_correlation(mu, cov, corr, color='blue', fig=fig)
    print(pm.summary(trace, varnames=['corr_coeffs', 'mu', 'sigma']))
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment