Created
October 21, 2017 21:44
-
-
Save mattjj/f7c7334021f3132dbbab5714aa03001a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import numpy.random as npr | |
import matplotlib.pyplot as plt | |
from matplotlib.widgets import Slider, RadioButtons | |
### matrix-normal inverse-Wishart (MNIW) sampling utilities
def sample_mniw(param, num_samples, rng=npr):
    """Draw (A, sigma) pairs from a matrix-normal inverse-Wishart prior.

    ``param`` is the 4-tuple ``(S, M, K, nu)``.  First ``sigma`` is drawn
    with ``sample_invwishart(S, nu, ...)``; then ``A`` is drawn with
    ``sample_matrix_normal(sigma, M, K, ...)``, i.e. with ``sigma`` as the
    row covariance and ``K`` as the column covariance around mean ``M``.

    ``nu`` is promoted to at least 1-D and only its first entry along the
    last axis is used as the degrees of freedom.

    Returns the tuple ``(A, sigma)``, each with a sample axis of length
    ``num_samples``.
    """
    S, M, K, nu = param
    dof = np.atleast_1d(nu)[..., 0]
    # sigma must be drawn before A: A's row covariance is the sampled sigma.
    sigma = sample_invwishart(S, dof, num_samples, rng)
    return sample_matrix_normal(sigma, M, K, num_samples, rng), sigma
def sample_invwishart(S, nu, num_samples, rng=npr):
    """Draw inverse-Wishart samples using the Bartlett decomposition.

    Parameters
    ----------
    S : array_like, shape (..., n, n)
        Scale matrix, or a batch of them.  Its Cholesky factor is taken, so it
        must be symmetric positive definite.
    nu : float or array_like
        Degrees of freedom.  It is reshaped to ``S.shape[:-2] + (1, 1)``, so it
        needs one entry per leading batch entry of ``S``.  The chi-square draws
        use ``nu - i`` for ``i = 0 .. n-1``, so ``nu`` must exceed ``n - 1``.
    num_samples : int
        Number of draws per scale matrix.
    rng : RandomState-like, optional
        Source of randomness (``chisquare`` and ``normal`` are used).  It
        defaults to the global ``numpy.random`` module.

    Returns
    -------
    ndarray, shape (..., num_samples, n, n)
        ``L (R^T R)^{-1} L^T``, where ``L = cholesky(S)`` and ``R`` is the
        upper-triangular Bartlett factor.

    Raises
    ------
    ValueError
        If ``S`` has fewer than two dimensions.
    """
    S = np.asarray(S)
    # Explicit check instead of ``assert``: asserts are stripped under -O.
    if S.ndim < 2:
        raise ValueError(
            "sample_invwishart expects S with shape (..., n, n), got ndim=%d"
            % S.ndim)
    n = S.shape[-1]
    # Insert a sample axis so L broadcasts against the per-sample factors.
    chol = np.expand_dims(np.linalg.cholesky(S), -3)
    # Bartlett: the squared diagonal entries are chi^2(nu - i), i = 0..n-1.
    dof = (np.reshape(nu, np.shape(S)[:-2] + (1, 1))
           - np.ones(num_samples)[:, None] * np.arange(n))
    chisq = rng.chisquare(dof)
    normal = rng.normal(size=S.shape[:-2] + (num_samples,) + S.shape[-2:])
    # Upper-triangular factor: sqrt(chi^2) on the diagonal, N(0, 1) above it.
    # NOTE(review): this matrix is already upper triangular, so the qr() pass
    # looks redundant; it is kept to preserve the original computation.
    R = qr(make_diag(np.sqrt(chisq)) + triu(normal, k=1))
    # X = L R^{-1}, obtained by solving R^T X^T = L^T rather than inverting R.
    X = T(np.linalg.solve(T(R), T(chol)))
    # X X^T = L R^{-1} R^{-T} L^T = L (R^T R)^{-1} L^T.
    return np.matmul(X, T(X))
def sample_matrix_normal(S, M, K, num_samples, rng=npr):
    """Draw matrix-normal samples ``M + chol(S) @ G @ chol(K)^T``.

    ``S`` acts as the row covariance (its Cholesky factor multiplies from the
    left) and ``K`` as the column covariance (its transposed Cholesky factor
    multiplies from the right).  ``G`` holds i.i.d. standard normals of shape
    ``M.shape[:-2] + (num_samples,) + M.shape[-2:]``.

    ``M`` and ``K`` get a sample axis inserted at position -3 so they
    broadcast against ``G``.  ``S`` is used as given, so it may carry its own
    sample axis (one covariance per draw) or be a single shared matrix.
    """
    sample_shape = M.shape[:-2] + (num_samples,) + M.shape[-2:]
    white = rng.normal(size=sample_shape)
    row_chol = np.linalg.cholesky(S)
    col_chol = np.linalg.cholesky(np.expand_dims(K, -3))
    mean = np.expand_dims(M, -3)
    colored = np.matmul(np.matmul(row_chol, white), T(col_chol))
    return mean + colored
### numerical utilities
def T(X):
    """Swap the last two axes of ``X``; inputs with ndim <= 1 are returned unchanged."""
    if np.ndim(X) <= 1:
        return X
    return np.swapaxes(X, -1, -2)
def make_diag(x):
    """Turn the trailing axis of ``x`` into the diagonal of square matrices.

    An input of shape (..., n) yields shape (..., n, n), with ``x[..., i]`` at
    position ``[..., i, i]`` and zeros elsewhere.  The result is floating
    point because it is built by multiplying with float identity/ones arrays.
    """
    x_shape = np.shape(x)
    batch_ones = np.ones(x_shape[:-1] + (1, 1))
    identity = np.eye(x_shape[-1])
    column = np.expand_dims(x, -1)
    return column * batch_ones * identity
def triu(a, k=0):
    """Zero the entries below the k-th diagonal of each trailing matrix of ``a``.

    ``a`` may have shape (..., m, n), including non-square m != n.  The mask
    is built from both trailing dimensions; the original used only the last
    one, which broke on non-square input.

    The triangle is applied as a multiplicative 0/1 float mask rather than by
    element selection.  The result is therefore always floating point, and a
    non-finite entry below the diagonal becomes nan (0 * inf and 0 * nan are
    nan).
    """
    a = np.asarray(a)
    # (m, n) mask broadcasts over any leading batch axes of ``a``.
    mask = np.triu(np.ones(a.shape[-2:]), k=k)
    return mask * a
def qr(a):
    """Return the R factor of the QR decomposition of every matrix in a stack.

    ``a`` has shape (..., m, n).  ``np.linalg.qr(x, 'r')`` is applied to each
    trailing (m, n) matrix, and the factors are stacked back under the leading
    batch axes, giving shape (..., min(m, n), n).

    The output shape is taken from the factors rather than from ``a``.  For
    tall matrices (m > n) the R factor has fewer rows than the input, so
    reshaping back to ``a.shape`` would fail.
    """
    a = np.asarray(a)
    flat = np.reshape(a, (-1,) + a.shape[-2:])
    out = np.stack([np.linalg.qr(x, 'r') for x in flat])
    return np.reshape(out, a.shape[:-2] + out.shape[-2:])
### visualization
def sample_eigenvalues(N, alpha, beta, gamma, num_points=1000, seed=0):
    """Pool the eigenvalues of random N x N matrices drawn from an MNIW prior.

    Each draw is an N x (N + 1) matrix ``Ab`` with these prior settings:
    - mean: 0.8 on the diagonal of the leading N x N block and zeros in the
      last column;
    - inverse-Wishart scale: ``alpha * I_N``;
    - column covariance: ``beta * I_{N+1}``;
    - degrees of freedom: ``N + N * gamma``.
    The last column is dropped and the eigenvalues of the remaining N x N
    block are returned for all draws, flattened into one array.

    Parameters
    ----------
    N : int
        Matrix size; must be >= 1.
    alpha, beta, gamma : float
        Prior hyperparameters, as described above.
    num_points : int, optional
        Target total number of eigenvalues.  ``max(1, num_points // N)``
        matrices are drawn, so at least one draw happens even when
        N > num_points.
    seed : int, optional
        Seed for a fresh ``RandomState``, so repeated calls with the same
        arguments return the same eigenvalues.

    Returns
    -------
    ndarray, shape (max(1, num_points // N) * N,)
        Eigenvalues (real or complex dtype, as returned by
        ``np.linalg.eigvals``).

    Raises
    ------
    ValueError
        If ``N < 1``.
    """
    if N < 1:
        raise ValueError("N must be a positive integer, got %r" % (N,))
    M = 0.8 * np.eye(N + 1)[:N]
    S = alpha * np.eye(N)
    K = beta * np.eye(N + 1)
    nu = N + N * gamma
    rng = np.random.RandomState(seed)
    # Zero draws would make the downstream sampler stack an empty list.
    num_samples = max(1, num_points // N)
    Ab, _ = sample_mniw((S, M, K, nu), num_samples, rng)
    # Keep the square block; the last column is presumably an affine offset
    # (the "Ab" naming suggests [A | b]) -- TODO confirm.
    A = Ab[..., :-1]
    # Pool all draws into one point cloud for plotting.
    return np.ravel(np.linalg.eigvals(A))
if __name__ == '__main__':
    # Interactive explorer: sliders set the MNIW hyperparameters and the plot
    # shows sampled eigenvalues for 2x2 (blue) and 10x10 (yellow) matrices
    # against the unit circle (red).
    fig, ax = plt.subplots()
    plt.subplots_adjust(left=0.25, bottom=0.25)
    ax.axis('equal')

    init_alpha, init_beta, init_gamma = 1., 1., 10.

    def eig_xy(n, alpha, beta, gamma):
        """Return (real parts, imaginary parts) of sampled eigenvalues for size n."""
        eigs = sample_eigenvalues(n, alpha, beta, gamma)
        return np.real(eigs), np.imag(eigs)

    # Initial point clouds: the 10x10 cloud is plotted first, then the 2x2 one.
    l2, = ax.plot(*eig_xy(10, init_alpha, init_beta, init_gamma), 'y.', alpha=0.2)
    l1, = ax.plot(*eig_xy(2, init_alpha, init_beta, init_gamma), 'b.', alpha=0.2)

    # Reference unit circle.
    theta = np.linspace(0, 2 * np.pi, 1000, endpoint=True)
    ax.plot(np.cos(theta), np.sin(theta), 'r-')

    # Hyperparameter sliders, stacked below the main axes (alpha on top).
    axcolor = 'lightgoldenrodyellow'
    axalpha, axbeta, axgamma = [
        plt.axes([0.25, bottom, 0.65, 0.03], facecolor=axcolor)
        for bottom in (0.15, 0.10, 0.05)
    ]
    salpha = Slider(axalpha, r'$\alpha$', 0.1, 10., valinit=init_alpha)
    sbeta = Slider(axbeta, r'$\beta$', 0.1, 10., valinit=init_beta)
    sgamma = Slider(axgamma, r'$\gamma$', 0.1, 30., valinit=init_gamma)

    # Mutable UI state shared by the callbacks below.
    state = {'gamma_equals_alpha': False}

    def update(_):
        if state['gamma_equals_alpha']:
            # Copy alpha into the gamma slider with its events muted, so this
            # programmatic set_val does not re-enter update().
            sgamma.eventson = False
            sgamma.set_val(salpha.val)
            sgamma.eventson = True
        for line, n in ((l1, 2), (l2, 10)):
            line.set_data(*eig_xy(n, salpha.val, sbeta.val, sgamma.val))
        fig.canvas.draw_idle()

    for slider in (salpha, sbeta, sgamma):
        slider.on_changed(update)

    # Radio buttons choosing whether gamma is free or tied to alpha.
    rax = plt.axes([0.025, 0.5, 0.15, 0.15], facecolor=axcolor)
    radio = RadioButtons(rax, (r'$\gamma$ free', r'$\gamma = \alpha$'), active=0)
    label_ties_gamma = {r'$\gamma$ free': False, r'$\gamma = \alpha$': True}

    def set_gamma_equals_alpha(label):
        state['gamma_equals_alpha'] = label_ties_gamma[label]
        update(None)

    radio.on_clicked(set_gamma_equals_alpha)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment