Skip to content

Instantly share code, notes, and snippets.

@mattjj
Created October 21, 2017 21:44
Show Gist options
  • Save mattjj/f7c7334021f3132dbbab5714aa03001a to your computer and use it in GitHub Desktop.
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, RadioButtons
### matrix-normal inverse-Wishart (MNIW) sampling utilities
def sample_mniw(param, num_samples, rng=npr):
    """Draw ``num_samples`` (A, Sigma) pairs from a matrix-normal inverse-Wishart.

    ``param`` is unpacked as ``(S, M, K, nu)``:
      S  -- scale matrix of the inverse-Wishart prior on Sigma,
      M  -- mean of the matrix-normal draw of A,
      K  -- right-hand (column) covariance factor of the matrix-normal draw,
      nu -- inverse-Wishart degrees of freedom; it is promoted to at least
            1-D and only its ``[..., 0]`` entry is used.
            NOTE(review): for batched S this takes nu[..., 0] only -- confirm
            the intended shape of nu.

    Sigma is drawn first and then used as the left-hand covariance of A, so
    the RNG is consumed in that order.

    Returns:
      Tuple ``(A, Sigma)``.
    """
    scale, mean, col_cov, dof = param
    first_dof = np.atleast_1d(dof)[..., 0]
    sigma_draws = sample_invwishart(scale, first_dof, num_samples, rng)
    a_draws = sample_matrix_normal(sigma_draws, mean, col_cov, num_samples, rng)
    return a_draws, sigma_draws
def sample_invwishart(S, nu, num_samples, rng=npr):
    """Draw ``num_samples`` inverse-Wishart matrices with scale S and dof nu.

    Uses a Bartlett-style construction. An upper-triangular factor R has
    sqrt(chi^2(nu - i)) draws (i = 0..n-1) on its diagonal and standard
    normals strictly above it. Then R^T R ~ Wishart(I, nu), and the returned
    X X^T = chol(S) (R^T R)^{-1} chol(S)^T is inverse-Wishart with scale S
    and nu degrees of freedom.

    Args:
      S: scale matrix (or stack of them) of shape (..., n, n). It is passed
        to np.linalg.cholesky, so it must be positive definite.
      nu: degrees of freedom. It is reshaped to S.shape[:-2] + (1, 1), so it
        needs exactly one entry per leading batch index of S. Every
        nu - i (i < n) must be a valid chi-square dof.
      num_samples: number of draws per batch element of S.
      rng: object providing ``chisquare`` and ``normal`` (numpy.random by
        default).

    Returns:
      Array of shape S.shape[:-2] + (num_samples, n, n).
    """
    # NOTE(review): assert is stripped under ``python -O``; consider raising
    # ValueError for invalid input instead.
    assert S.ndim >= 2
    n = S.shape[-1]
    # Lower Cholesky factor of S, with a singleton axis inserted before the
    # matrix axes so it broadcasts against the num_samples axis below.
    chol = np.expand_dims(np.linalg.cholesky(S), -3)
    # Bartlett diagonal: chi-square draws with nu - i degrees of freedom for
    # column i, shape S.shape[:-2] + (num_samples, n).
    chisq = rng.chisquare(np.reshape(nu, np.shape(S)[:-2] + (1, 1))
                          - np.ones(num_samples)[:,None] * np.arange(n))
    # Standard normals; only the strict upper triangle is kept (triu k=1).
    normal = rng.normal(size=S.shape[:-2] + (num_samples,) + S.shape[-2:])
    # Upper-triangular Bartlett factor, so R^T R ~ Wishart(I, nu).
    # NOTE(review): the argument to qr is already upper triangular, so this QR
    # appears only to flip row signs (leaving R^T R unchanged) -- confirm intent.
    R = qr(make_diag(np.sqrt(chisq)) + triu(normal, k=1))
    # X = chol @ inv(R), computed by a general solve on the transposes:
    # solve(R^T, chol^T) = R^{-T} chol^T, and transposing gives chol R^{-1}.
    X = T(np.linalg.solve(T(R), T(chol)))
    # X X^T = chol (R^T R)^{-1} chol^T
    return np.matmul(X, T(X))
def sample_matrix_normal(S, M, K, num_samples, rng=npr):
    """Draw ``num_samples`` matrix-normal samples M + chol(S) G chol(K)^T.

    G is a standard-normal array of shape M.shape[:-2] + (num_samples,) +
    M.shape[-2:]. A singleton sample axis is inserted into M and K, just
    before their matrix axes, so they broadcast over the draws. S must
    already broadcast against the drawn samples. For example, it can be the
    per-sample covariances returned by an inverse-Wishart sampler.

    Args:
      S: left (row) covariance, shape broadcastable to (..., num_samples, n, n).
      M: mean, shape (..., n, p).
      K: right (column) covariance, shape (..., p, p).
      num_samples: number of draws.
      rng: object providing ``normal`` (numpy.random by default).

    Returns:
      Array of shape (..., num_samples, n, p).
    """
    noise = rng.normal(size=M.shape[:-2] + (num_samples,) + M.shape[-2:])
    mean = np.expand_dims(M, -3)
    left_factor = np.linalg.cholesky(S)
    right_factor = np.linalg.cholesky(np.expand_dims(K, -3))
    scaled = np.matmul(left_factor, noise)
    # right_factor always has ndim >= 3 here, so transposing is a plain swap of
    # the last two axes.
    return mean + np.matmul(scaled, np.swapaxes(right_factor, -1, -2))
### numerical utilities
def T(X):
    """Transpose the last two axes of X; inputs with ndim <= 1 are returned unchanged."""
    if np.ndim(X) <= 1:
        return X
    return np.swapaxes(X, -2, -1)
def make_diag(x):
    """Place the last axis of x on the diagonal of a (..., n, n) array.

    x has shape (..., n). Entry [..., i, j] of the result is x[..., i] when
    i == j and zero otherwise. The product with np.eye promotes the result
    to at least float64.
    """
    values = np.asarray(x)
    # (..., n, 1) * (n, n) broadcasts to (..., n, n) and zeroes off-diagonals.
    column = values[..., :, np.newaxis]
    return column * np.eye(values.shape[-1])
def triu(a, k=0):
    """Zero out entries below the k-th diagonal of the last two axes of a.

    Works on a single matrix or on a stack of shape (..., m, n); the matrices
    need not be square. As in the original mask-multiply version, the result
    dtype is np.result_type(a.dtype, float64), so integer input comes back as
    float64. Masked entries are exactly zero even if they held nan or inf.

    Args:
      a: array-like of shape (..., m, n).
      k: diagonal offset, with the same meaning as in np.triu (0 is the main
        diagonal; k > 0 also drops that many superdiagonals).

    Returns:
      Array of the same shape as a.
    """
    a = np.asarray(a)
    # (m, n) keep-mask; the original built an (n, n) mask, which failed to
    # broadcast against non-square matrices.
    keep = np.triu(np.ones(a.shape[-2:], dtype=bool), k=k)
    out_dtype = np.result_type(a.dtype, np.float64)
    # np.where rather than multiplying by the mask, so nan/inf below the
    # diagonal do not survive as nan (nan * 0 == nan).
    return np.where(keep, a, 0).astype(out_dtype)
def qr(a):
    """Return the R factor of the QR decomposition of each matrix in a stack.

    Args:
      a: array-like of shape (..., m, n).

    Returns:
      Array of shape (..., min(m, n), n); this equals a.shape for square
      matrices. An empty batch (some leading dimension of size 0) gives an
      empty float64 result of the right shape instead of crashing.

    Raises:
      np.linalg.LinAlgError: if a has fewer than two dimensions (the same
        error np.linalg.qr raises for such input).
    """
    a = np.asarray(a)
    if a.ndim < 2:
        raise np.linalg.LinAlgError(
            '%d-dimensional array given. Array must be at least '
            'two-dimensional' % a.ndim)
    batch_shape = a.shape[:-2]
    m, n = a.shape[-2:]
    r_rows = min(m, n)  # mode 'r' returns R with shape (min(m, n), n)
    # Explicit batch size instead of -1: reshape(-1, ...) is ambiguous when the
    # array has zero size.
    num_mats = int(np.prod(batch_shape, dtype=int))
    flat = np.reshape(a, (num_mats, m, n))
    # Per-matrix loop, as in the original, so this does not rely on
    # np.linalg.qr accepting stacked input.
    factors = [np.linalg.qr(x, 'r') for x in flat]
    if factors:
        out = np.stack(factors)
    else:
        out = np.zeros((0, r_rows, n))  # np.stack([]) would raise
    return np.reshape(out, batch_shape + (r_rows, n))
### visualization
def sample_eigenvalues(N, alpha, beta, gamma, num_points=1000, seed=0):
    """Sample eigenvalues of N x N matrices drawn from an MNIW prior.

    The prior passed to sample_mniw is built as follows:
      M  = 0.8 * [I_N | 0]  (shape N x (N + 1)),
      S  = alpha * I_N,
      K  = beta * I_{N+1},
      nu = N + N * gamma.
    Each (N, N + 1) draw has its last column dropped before taking
    eigenvalues.

    Args:
      N: matrix dimension.
      alpha, beta, gamma: prior hyperparameters (see above).
      num_points: approximate total number of eigenvalues wanted. About
        num_points // N matrices are drawn, and always at least one. The
        original crashed for N > num_points because it requested zero draws.
      seed: seed for a fresh np.random.RandomState. Calls with the same
        arguments therefore return the same points, presumably so the plot
        does not jitter as sliders move.

    Returns:
      1-D array (real or complex, as returned by np.linalg.eigvals) of length
      N * max(1, num_points // N).
    """
    M = 0.8 * np.eye(N + 1)[:N]
    S = alpha * np.eye(N)
    K = beta * np.eye(N + 1)
    nu = N + N * gamma
    rng = np.random.RandomState(seed)
    num_samples = max(1, num_points // N)
    Ab, _ = sample_mniw((S, M, K, nu), num_samples, rng)
    # Keep the square N x N part; the extra last column is presumably an
    # affine/bias term (given the name Ab) -- confirm.
    A = Ab[..., :-1]
    # eigvals returns shape (num_samples, N); flatten into one point cloud.
    return np.ravel(np.linalg.eigvals(A))
if __name__ == '__main__':
    # Interactive plot of MNIW-prior eigenvalue clouds (N=10 yellow, N=2 blue)
    # against the unit circle, with sliders for alpha, beta, gamma and a radio
    # button that can tie gamma to alpha.
    fig, ax = plt.subplots()
    plt.subplots_adjust(left=0.25, bottom=0.25)
    ax.axis('equal')

    init_alpha, init_beta, init_gamma = 1., 1., 10.
    init_params = (init_alpha, init_beta, init_gamma)

    # Initial point clouds; the N=10 cloud is drawn first, under the N=2 one.
    eigs_big = sample_eigenvalues(10, *init_params)
    l2, = ax.plot(np.real(eigs_big), np.imag(eigs_big), 'y.', alpha=0.2)
    eigs_small = sample_eigenvalues(2, *init_params)
    l1, = ax.plot(np.real(eigs_small), np.imag(eigs_small), 'b.', alpha=0.2)

    # Unit circle for reference.
    theta = np.linspace(0, 2*np.pi, 1000, endpoint=True)
    ax.plot(np.cos(theta), np.sin(theta), 'r-')

    # Slider axes (alpha on top, gamma at the bottom), then the sliders.
    axcolor = 'lightgoldenrodyellow'
    axalpha, axbeta, axgamma = [
        plt.axes([0.25, bottom, 0.65, 0.03], facecolor=axcolor)
        for bottom in (0.15, 0.10, 0.05)]
    salpha = Slider(axalpha, r'$\alpha$', 0.1, 10., valinit=init_alpha)
    sbeta = Slider(axbeta, r'$\beta$', 0.1, 10., valinit=init_beta)
    sgamma = Slider(axgamma, r'$\gamma$', 0.1, 30., valinit=init_gamma)

    def update(_):
        """Resample both point clouds from the current slider values."""
        if gamma_equals_alpha:
            # Mirror alpha into the gamma slider without re-entering update().
            sgamma.eventson = False
            sgamma.set_val(salpha.val)
            sgamma.eventson = True
        params = (salpha.val, sbeta.val, sgamma.val)
        for dim, line in ((2, l1), (10, l2)):
            eigs = sample_eigenvalues(dim, *params)
            line.set_data(np.real(eigs), np.imag(eigs))
        fig.canvas.draw_idle()

    for slider in (salpha, sbeta, sgamma):
        slider.on_changed(update)

    # Radio buttons choosing whether gamma is free or tied to alpha.
    rax = plt.axes([0.025, 0.5, 0.15, 0.15], facecolor=axcolor)
    radio = RadioButtons(rax, (r'$\gamma$ free', r'$\gamma = \alpha$'), active=0)
    gamma_equals_alpha = False

    def set_gamma_equals_alpha(label):
        """Record the radio choice in the module-level flag and redraw."""
        global gamma_equals_alpha
        gamma_equals_alpha = {r'$\gamma$ free': False,
                              r'$\gamma = \alpha$': True}[label]
        update(None)

    radio.on_clicked(set_gamma_equals_alpha)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment