Created
June 14, 2018 18:28
-
-
Save lhk/9d55662e5e163bd40171c0f8e038aec1 to your computer and use it in GitHub Desktop.
Multi-dimensional Metropolis-Hastings sampler in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
import numpy as np | |
from scipy.stats import multivariate_normal | |
def metropolis_hastings(f, num_samples, burn_in, result_queue=None,
                        dim=2, seed=None):
    """Sample from ``f`` with a random-walk Metropolis-Hastings chain.

    The chain starts at the origin and proposes ``x + step`` where ``step``
    is drawn from a standard normal. The proposal is symmetric, so the
    acceptance ratio reduces to ``f.pdf(x_next) / f.pdf(x)``.

    Args:
        f: Target distribution; any object with a ``pdf(x)`` method that
            returns the (possibly unnormalized) density at a ``dim``-vector.
        num_samples: Total number of chain steps to run.
        burn_in: Number of leading samples to drop from the result.
        result_queue: Optional queue-like object; if given, the samples are
            also ``put`` on it (used to return results from a subprocess).
        dim: Dimensionality of the sample space. Defaults to 2.
        seed: Optional seed for this call's private random generator.
            ``None`` seeds from OS entropy.

    Returns:
        Array of shape ``(max(num_samples - burn_in, 0), dim)`` holding the
        chain state after each step, with the burn-in prefix removed.

    Raises:
        ValueError: If ``burn_in`` is negative (a negative value would
            silently slice from the end of the chain instead).
    """
    if burn_in < 0:
        raise ValueError("burn_in must be non-negative")

    # Use a private generator instead of the global np.random state: forked
    # worker processes inherit an identical copy of the global state and
    # would otherwise all produce exactly the same chain.
    rng = np.random.RandomState(seed)

    # sample normal values as stepsize for the updates
    # important: the proposal is symmetric, so it does not appear in the
    # calculation of alpha below
    steps = rng.normal(0, 1, (num_samples, dim))

    # with some bookkeeping, the pdf of f is only evaluated once per loop
    # iteration; that bookkeeping is initialized here
    x = np.zeros((dim,))
    # TODO: you can introduce a different stepsize by scaling these
    x_next = x + rng.normal(0, 2, dim)
    current_prob = f.pdf(x)
    next_prob = f.pdf(x_next)

    x_chosen = np.zeros((num_samples, dim))
    for i in range(num_samples):
        # Guard against a pdf of 0. This can happen with a poor starting
        # point, or when the stepsize is so large that proposals leave the
        # support of f.
        if current_prob == 0:
            # we always accept the next sample
            alpha = 1
        elif next_prob == 0:
            # we never accept the next sample
            alpha = 0
        else:
            # this is the normal MH alpha calculation
            alpha = next_prob / current_prob

        if rng.rand() < alpha:
            x = x_next
            current_prob = next_prob

        # prepare the proposal for the next iteration
        x_next = x + steps[i]
        next_prob = f.pdf(x_next)

        x_chosen[i] = x

    x_final = x_chosen[burn_in:]

    if result_queue is not None:
        result_queue.put(x_final)

    return x_final
from multiprocessing import Process
from multiprocessing import Queue

# Target distribution and sampler settings for the demo run.
f = multivariate_normal([5, 5], np.eye(2) * 2)
num_samples = 100000
burn_in = 1000
num_processes = 12

# The guard is required by multiprocessing: with spawn-based start methods
# the child re-imports this module, and without the guard it would try to
# start workers of its own.
if __name__ == "__main__":
    result_queue = Queue(maxsize=10)

    # run one independent chain per worker process
    processes = []
    for _ in range(num_processes):
        p = Process(
            target=metropolis_hastings,
            args=(f, num_samples, burn_in, result_queue),
        )
        p.start()
        processes.append(p)

    # Drain the queue BEFORE joining: a worker cannot exit while the data it
    # put on the queue is still unflushed.
    results = []
    for _ in processes:
        results.append(result_queue.get())

    for p in processes:
        p.join()

    # pool the samples of all chains
    x_final = np.vstack(results)

    print("the correct mean is {}, we have {}".format(f.mean, x_final.mean(axis=0)))
    # f.cov and np.cov are covariance matrices, not standard deviations
    print("the correct covariance is {}, we have {}".format(f.cov, np.cov(x_final.T)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment