# Last active: September 12, 2017 02:58
# Save mattjj/4864a3ef55e21a19a4784b30a660f83b to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an editor
# that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import operator as op

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import jacobian, make_hvp

# Seed the global random state behind npr.randn so the random test matrix
# drawn in the quick check is reproducible from run to run.
npr.seed(0)
## direct computation
def prod(lst):
    """Return the product of all items in ``lst``, combined with ``*``.

    ``lst`` may be any iterable, including a generator; items only need to
    support ``*`` with each other (e.g. broadcastable arrays).  An empty
    iterable returns ``1``, the multiplicative identity, instead of raising.
    """
    # The bare builtin ``reduce`` was removed in Python 3; ``functools.reduce``
    # is available on Python 2.6+ and 3, so import it locally.
    from functools import reduce
    # Initializer 1 makes the empty product well defined (needed for k == 0
    # moments, whose single pairing contains no pairs).
    return reduce(op.mul, lst, 1)
def lift(mat, pair, ndim):
    """Reshape ``mat`` into an ``ndim``-dimensional array for broadcasting.

    The axes of ``mat`` are placed, in order, at the positions listed in
    ``pair``; every other axis has size 1.  Positions are filled in increasing
    axis order, so the order of indices inside ``pair`` does not matter.
    """
    dims = iter(mat.shape)
    lifted_shape = []
    for axis in range(ndim):
        if axis in pair:
            # Consume the next axis length of ``mat`` for this slot.
            lifted_shape.append(next(dims))
        else:
            lifted_shape.append(1)
    return np.reshape(mat, lifted_shape)
# from https://stackoverflow.com/a/5360442
def all_pairs(lst):
    """Yield every way to partition ``lst`` into unordered pairs.

    Each partition is a list of 2-tuples ``(a, b)`` where ``a`` precedes ``b``
    in ``lst``.  An empty input yields a single empty partition ``[]``.  If
    ``lst`` has odd length, one element stays unpaired and appears as a bare
    item at the end of each yielded list.

    ``lst`` may be any finite iterable (list, tuple, ``range``, ...).
    """
    # Materialize as a fresh list.  The recursion concatenates slices with
    # ``+``, which Python 3 ``range`` objects do not support and which fails
    # when a tuple slice meets the list built below.  It also avoids yielding
    # the caller's own object in the base case.
    lst = list(lst)
    if len(lst) < 2:
        yield lst
        return
    first = lst[0]
    for i in range(1, len(lst)):
        # Pair the first element with each later one, then recursively pair
        # whatever remains.
        pair = (first, lst[i])
        for rest in all_pairs(lst[1:i] + lst[i + 1:]):
            yield [pair] + rest
# see https://en.wikipedia.org/wiki/Isserlis%27_theorem
def moment(sigma, k):
    """Return the k-th moment tensor of a zero-mean Gaussian with covariance ``sigma``.

    Uses Isserlis' theorem:
    - for odd ``k``, E[x_{i1} ... x_{ik}] is zero;
    - for even ``k``, it is the sum, over all pairings of the ``k`` indices,
      of the product of the paired covariance entries.

    Each pair's factor is ``sigma`` lifted into a ``k``-dimensional array, so
    the product over a pairing is formed by broadcasting.

    Args:
        sigma: covariance matrix, assumed square with shape (N, N).
        k: non-negative moment order.

    Returns:
        An array of shape ``(N,) * k``.  For ``k == 0`` this is a 0-d array
        holding 1.0.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError("moment order k must be non-negative, got %r" % (k,))
    N = sigma.shape[0]
    if k % 2:
        return np.zeros((N,) * k)
    if k == 0:
        # The zeroth moment E[1] is 1; handled explicitly rather than via an
        # empty product over the single empty pairing.
        return np.ones(())
    # Pass a list (not a bare range) so the pairing recursion can concatenate
    # slices on Python 3.  There are (k-1)!! pairings, so cost grows quickly
    # with k.
    return sum(prod(lift(sigma, pair, k) for pair in group)
               for group in all_pairs(list(range(k))))
## autograd computation
def logZ(neghalfJ, h):
    """Log-normalizer of a Gaussian in information form, up to an additive constant.

    Given the natural parameters ``neghalfJ = -J/2`` and ``h``, returns
    ``0.5 * h^T J^{-1} h - 0.5 * log det J``.  The omitted constant is
    ``(n/2) * log(2*pi)``.  The Cholesky factor supplies the log-determinant,
    so ``J`` must be positive definite or the factorization fails.
    """
    precision = -2 * neghalfJ
    chol = np.linalg.cholesky(precision)
    quad_term = 0.5 * np.dot(h, np.linalg.solve(precision, h))
    # sum(log(diag(L))) equals 0.5 * log det J for J = L L^T.
    half_logdet = np.sum(np.log(np.diag(chol)))
    return quad_term - half_logdet
def gaussian_mgf(sigma, mu):
    """Return the moment generating function ``t -> E[exp(t . x)]`` of x ~ N(mu, sigma).

    Converts to information form (``-J/2 = -sigma^{-1}/2``, ``h = sigma^{-1} mu``).
    It then uses the fact that tilting by ``exp(t . x)`` shifts ``h`` by ``t``.
    The MGF is therefore the ratio of the shifted and unshifted partition
    functions.
    """
    precision_term = -1./2 * np.linalg.inv(sigma)
    potential = np.linalg.solve(sigma, mu)

    def mgf(t):
        return np.exp(logZ(precision_term, t + potential)
                      - logZ(precision_term, potential))

    return mgf
def moment2(sigma, k):
    """Return the k-th moment tensor of N(0, sigma) by differentiating its MGF.

    Applies ``autograd.jacobian`` ``k`` times to the moment generating function
    and evaluates the result at the origin, giving an array of shape
    ``(N,) * k`` for an (N, N) ``sigma``.
    """
    dim = sigma.shape[0]
    origin = np.zeros(dim)
    deriv = gaussian_mgf(sigma, np.zeros(dim))
    for _level in range(k):
        deriv = jacobian(deriv)
    return deriv(origin)
def gaussian_mgf2(sigma, mu):
    """Return ``T -> E[exp(sum_ij T_ij x_i x_j)]`` for x ~ N(mu, sigma).

    Tilting the density by ``exp(<T, x x^T>)`` adds ``T`` to the natural
    parameter ``-J/2``, so the result is a ratio of partition functions.  It is
    defined only where ``-2 * (-J/2 + T)`` stays positive definite.  ``T`` is
    presumably a symmetric matrix shaped like ``sigma`` — confirm with callers.
    """
    precision_term = -1./2 * np.linalg.inv(sigma)
    potential = np.linalg.solve(sigma, mu)

    def mgf2(T):
        return np.exp(logZ(precision_term + T, potential)
                      - logZ(precision_term, potential))

    return mgf2
## quick check
# Each print below should output True.  print(x) with a single argument behaves
# identically on Python 2 and Python 3, unlike the Python-2-only print statement.
rng = npr.RandomState(0)
# Random symmetric positive semi-definite 2x2 covariance.
sigma = (lambda X: np.dot(X, X.T))(rng.randn(2, 2))

# The second moment of a zero-mean Gaussian is its covariance.
print(np.allclose(sigma, moment(sigma, 2)))
# Isserlis-theorem moments agree with moments from differentiating the MGF.
print(np.allclose(moment(sigma, 2), moment2(sigma, 2)))
print(np.allclose(moment(sigma, 3), moment2(sigma, 3)))
print(np.allclose(moment(sigma, 4), moment2(sigma, 4)))

# The Hessian of the second-order MGF at T = 0, applied to V, should equal the
# fourth moment tensor contracted with V over its last two axes.  The second
# value returned by make_hvp is unused.
hvp, _ = make_hvp(gaussian_mgf2(sigma, np.zeros(2)))(np.zeros((2, 2)))
V = npr.randn(2, 2)
print(np.allclose(hvp(V), np.tensordot(moment(sigma, 4), V, 2)))
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.