Random Sampling
Last active: May 20, 2017 17:46
Save qxj/66d42234f58d519b2511ed2892cf7411 to your computer and use it in GitHub Desktop.
Random Sampling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
http://www.nehalemlabs.net/prototype/blog/2014/02/24/an-introduction-to-the-metropolis-method-with-python/ | |
''' | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.mlab as mlab | |
def _bivariate_normal(x, y, sigmax, sigmay, mux, muy, sigmaxy):
    """Return the bivariate Gaussian pdf evaluated at (x, y).

    Drop-in replacement for ``matplotlib.mlab.bivariate_normal``, which was
    removed in Matplotlib 3.1. As in mlab, ``sigmaxy`` is the covariance
    (not the correlation coefficient). x and y may be scalars or
    broadcast-compatible arrays.
    """
    xmu = x - mux
    ymu = y - muy
    rho = sigmaxy / (sigmax * sigmay)  # correlation coefficient
    z = (xmu ** 2 / sigmax ** 2 + ymu ** 2 / sigmay ** 2
         - 2 * rho * xmu * ymu / (sigmax * sigmay))
    denom = 2 * np.pi * sigmax * sigmay * np.sqrt(1 - rho ** 2)
    return np.exp(-z / (2 * (1 - rho ** 2))) / denom


def q(x, y):
    """Evaluate the target density: a two-component bivariate Gaussian mixture.

    Component 1: mean (-1, -1), sigma_x = sigma_y = 1.0, covariance -0.8.
    Component 2: mean (1, 2), sigma_x = 1.5, sigma_y = 0.8, covariance 0.6.
    Mixture weights 0.6 and 28.4 are divided by their sum, so the result is a
    properly normalised density.

    x, y: scalars or broadcast-compatible arrays of coordinates.
    Returns the density with the broadcast shape of x and y.
    """
    g1 = _bivariate_normal(x, y, 1.0, 1.0, -1, -1, -0.8)
    g2 = _bivariate_normal(x, y, 1.5, 0.8, 1, 2, 0.6)
    # The whole weighted sum is normalised. The previous expression
    # `0.6*g1+28.4*g2/(0.6+28.4)` divided only the second term by precedence.
    return (0.6 * g1 + 28.4 * g2) / (0.6 + 28.4)
# --- Metropolis-Hastings sampling of the target density q ---
# Random-walk proposal r + N(0, I). Because the Gaussian step is symmetric,
# the Hastings correction cancels and a move is accepted with probability
# min(1, q(proposal) / q(current)).
N = 100000        # total number of sampler steps
s = 10            # thinning interval: keep every s-th state
r = np.zeros(2)   # current state, started at the origin
p = q(r[0], r[1])  # target density at the current state
print(p)
samples = []
for i in range(N):
    rn = r + np.random.normal(size=2)  # proposed state
    pn = q(rn[0], rn[1])
    if pn >= p:
        # Moves to equal or higher density are always accepted.
        p = pn
        r = rn
    else:
        # Moves to lower density are accepted with probability pn / p.
        u = np.random.rand()
        if u < pn / p:
            p = pn
            r = rn
    if i % s == 0:
        # r is rebound on acceptance, never mutated in place, so storing
        # the reference is safe.
        samples.append(r)
# Turn the thinned chain into an (M, 2) array and draw it as a point cloud.
samples = np.array(samples)
xs, ys = samples.T
plt.scatter(xs, ys, alpha=0.5, s=1)

# Overlay contour lines of the target density q. The grid is square and spans
# from the smallest to the largest coordinate seen in either dimension.
dx = 0.01
lo, hi = samples.min(), samples.max()
x = np.arange(lo, hi, dx)
y = np.arange(lo, hi, dx)
X, Y = np.meshgrid(x, y)
Z = q(X, Y)
CS = plt.contour(X, Y, Z, 10)
plt.clabel(CS, inline=1, fontsize=10)
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
def mcmc(p=np.array([.1, .2, .3, .4]), n=10,
         converge_threshold=10, is_mh=True):
    """Run ``n`` parallel discrete MCMC chains over states ``0..len(p)-1``.

    The proposal matrix ``Q`` uses ``p`` itself as every row, so a candidate
    state is drawn from ``p`` regardless of the current state. On each step
    one chain is picked at random, a candidate ``y`` is proposed from its
    current state ``x``, and the candidate is accepted with probability

    * ``min(1, p[y]*Q[y][x] / (p[x]*Q[x][y]))`` when ``is_mh`` is True
      (Metropolis-Hastings), or
    * ``p[y]*Q[y][x]`` otherwise (basic MCMC acceptance rate).

    An accepted state change resets a stability counter. An accepted proposal
    that leaves the chain unchanged increments it. Rejected proposals do
    neither. The run stops on the accepted unchanged proposal that arrives
    once the counter has reached ``converge_threshold``. The number of
    proposals needed can grow quickly with ``converge_threshold``.

    Args:
        p: target probability vector (array-like; presumably sums to 1).
            It is not modified.
        n: number of parallel chains.
        converge_threshold: number of accepted unchanged proposals in a row
            (without an intervening state change) required before stopping.
        is_mh: choose the Metropolis-Hastings acceptance rule.

    Returns:
        List of length ``n`` holding the final state (an int) of each chain.
    """
    p = np.asarray(p, dtype=np.float64)
    num_states = len(p)
    # Arbitrary transition (proposal) matrix; here the given distribution is
    # used directly as every row. It is float64 (previously float32) so the
    # rows passed to np.random.multinomial keep full precision.
    Q = np.tile(p, (num_states, 1))
    # Random initial state for each chain.
    x0 = [int(np.random.randint(num_states)) for _ in range(n)]
    converge_num = 0   # accepted "unchanged" proposals since last change
    sample_count = 0   # total number of proposals drawn
    while True:
        idx = np.random.randint(n)  # chain updated on this step
        x = x0[idx]
        y = int(np.argmax(np.random.multinomial(1, Q[x])))
        sample_count += 1
        # Acceptance rate for the move x -> y, i.e. alpha = p[j]*Q[j][i].
        if is_mh:
            denom = p[x] * Q[x][y]
            # A chain in a zero-probability state always moves on. This
            # avoids a 0/0, for which the previous min() also yielded 1.
            alpha = 1.0 if denom == 0 else min(1.0, p[y] * Q[y][x] / denom)
        else:
            alpha = p[y] * Q[y][x]
        if np.random.random_sample() < alpha:
            if y != x:
                # State changed: record it and reset the stability counter.
                x0[idx] = y
                converge_num = 0
            elif converge_num >= converge_threshold:
                # State unchanged many times in a row: treat as converged.
                print('收敛状态:{}'.format(x0))
                print('采样计数:{}'.format(sample_count))
                return x0
            else:
                # Unchanged but not yet stable: increase the stability count.
                converge_num += 1
# Run 100 parallel chains, then histogram their final states (4 bins, one
# per state of mcmc's default 4-state distribution).
samples = mcmc(n=100)
plt.hist(samples, 4)
# Without show() a non-interactive run exits without displaying anything.
# This matches the Metropolis demo in this gist, which ends with plt.show().
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.