Created June 15, 2019 05:19.
Save sidravi1/a7965d57c63e71f9b9ff47098cd774df to your computer and use it in GitHub Desktop.
HMC animation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import autograd.numpy as np | |
import scipy.stats as st | |
import matplotlib.pyplot as plt | |
import matplotlib.animation as animation | |
import matplotlib as mpl | |
import seaborn as sns | |
from minimc.minimc.minimc_slow import hamiltonian_monte_carlo as hmc_slow | |
from minimc.minimc import neg_log_normal, mixture | |
# Figure size for the animation window.
FIGSIZE = (10, 7)

# Target density: equal-weight mixture of three normals, given as (mean, sd).
mixture_params = [(0, 0.1), (0.5, 0.2), (-0.5, 0.2)]
n_mix = len(mixture_params)
mixture_norm = [neg_log_normal(mu, sd) for mu, sd in mixture_params]
p_mix = [1 / n_mix] * n_mix
mixture_logp = mixture(mixture_norm, p_mix)

# Seed *before* sampling so the HMC run (and hence the animation) is
# reproducible. Previously the seed was only set after sampling had finished.
# NOTE(review): assumes hmc_slow draws from NumPy's global RNG — confirm.
np.random.seed(100)

# 50 HMC iterations. Besides the samples, hmc_slow returns per-iteration
# position/momentum sequences, which are animated below.
samples, positions, momentums, accepted = hmc_slow(50, mixture_logp,
                                                   initial_position=0.,
                                                   path_len=1.0,
                                                   step_size=0.01)

# Concatenate the per-iteration sequences into one long phase-space path.
pos_vec, mom_vec = np.hstack(positions), np.hstack(momentums)
def init():
    """Reset the axes limits and clear every animated artist.

    Used as ``init_func`` for ``FuncAnimation``. All four animated artists
    are returned (not just two) so each would be redrawn if blitting were
    enabled.
    """
    ax.set_ylim(-3.5, 3.5)
    ax.set_xlim(-1.0, 1.0)
    for artist in (line, star, sample_points, selected_pos_mom):
        artist.set_data([], [])
    return line, star, sample_points, selected_pos_mom
def run(data):
    """Draw one animation frame.

    Parameters
    ----------
    data : tuple
        ``(p, m, s, pm)`` as yielded by ``data_gen``: the recent trail of
        positions, the matching momenta, the samples collected so far, and an
        array of ``[position, momentum]`` points where they were collected.

    Returns
    -------
    tuple
        The animated artists updated by this frame.
    """
    p, m, s, pm = data
    # Trail of the trajectory in (position, momentum) space.
    line.set_data(p, m)
    # Collected samples drawn as a rug near the bottom of the y-range.
    sample_points.set_data(s, [-3.3] * len(s))
    if pm.shape[0] > 0:
        selected_pos_mom.set_data(pm[:, 0], pm[:, 1])
    # Current point. Slicing keeps x and y as sequences (empty when p is
    # empty); Line2D.set_data no longer accepts bare scalars.
    star.set_data(p[-1:], m[-1:])
    return line, star, sample_points, selected_pos_mom
def data_gen(trail_len=100):
    """Yield one animation frame per point of the flattened HMC path.

    Parameters
    ----------
    trail_len : int, optional
        How many points before the current one are included in the trail
        (default 100, the previously hard-coded value).

    Yields
    ------
    tuple
        ``(positions, momenta, selected, selected_pos_mom)``. ``positions``
        and ``momenta`` are the current path point plus up to ``trail_len``
        preceding points. ``selected`` lists the samples collected so far.
        ``selected_pos_mom`` is an array of the ``[position, momentum]``
        points where each was collected.
    """
    i = 0
    sample_selected = []
    sample_pos_mom = []
    for cnt in range(mom_vec.shape[0]):
        # Builtin max: np.max(cnt - 100, 0) treats 0 as an *axis* argument,
        # not a floor, so the start went negative on early frames and the
        # wrapped slice produced an empty/incorrect trail.
        low_val = max(cnt - trail_len, 0)
        # A repeated position is treated as the boundary between two HMC
        # iterations. NOTE(review): assumes each repeat lines up with the
        # next entry of `samples` — confirm against hmc_slow's output.
        if cnt > 0 and pos_vec[cnt] == pos_vec[cnt - 1]:
            sample_selected.append(samples[i])
            sample_pos_mom.append([pos_vec[cnt], mom_vec[cnt]])
            i += 1
        end = cnt + 1
        yield (pos_vec[low_val:end], mom_vec[low_val:end],
               sample_selected, np.array(sample_pos_mom))
with plt.style.context('Solarize_Light2'):
    print(mpl.__version__)
    # Use the module-level FIGSIZE constant for the figure.
    fig, ax = plt.subplots(figsize=FIGSIZE)
    mus = [mu for mu, _ in mixture_params]
    sds = [sd for _, sd in mixture_params]
    ax.set_xlabel("Position")
    ax.set_ylabel("Momentum")
    ax.grid(ls="--")

    # Reference density: 10000 draws from each (equally weighted) mixture
    # component, shown as a KDE on a twin y-axis.
    actual_samples = st.norm(mus, sds).rvs([10000, n_mix])
    ax2 = ax.twinx()
    ax2.grid(False)
    # NOTE(review): newer seaborn deprecates `shade=` in favour of `fill=`;
    # `shade=` is kept for compatibility with older seaborn releases.
    sns.kdeplot(actual_samples.ravel(), shade=True, ax=ax2)

    # Animated artists, updated by init() and run().
    line, = ax.plot([], [], lw=2, color='firebrick')
    star, = ax.plot([], [], "*r")
    sample_points, = ax.plot([], [], ".k")
    selected_pos_mom, = ax.plot([], [], "x", color='orange')

    # save_count must be an integer frame count; a float breaks the
    # slicing/sizing matplotlib does on its saved-frame cache.
    ani = animation.FuncAnimation(fig, run, data_gen, blit=False, interval=1,
                                  repeat=False, init_func=init,
                                  save_count=int(mom_vec.shape[0] * 1.5))
    # To export the animation instead of showing it, for example:
    # ani.save('animation.gif', writer='imagemagick', fps=60)
    # ani.save('anim_hmc.mp4', 'ffmpeg', fps=40)
    plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.