Last active
March 5, 2016 10:50
-
-
Save taku-y/0c7082a2ab7eadacb2d6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from pymc3 import Model, MvNormal | |
import theano | |
import theano.tensor as T | |
def run_check_logdet():
    """Run ``check_logdet`` for alpha0 = 1e4, 1e3 and 1e2, in that order.

    Before each run a header line of the form ``alpha0=<value>`` is printed
    so the two logp lines emitted by ``check_logdet`` can be attributed.
    """
    schedule = (('1e4', 1e4), ('1e3', 1e3), ('1e2', 1e2))
    for label, alpha0 in schedule:
        print('alpha0=' + label)
        check_logdet(alpha0=alpha0)
def _simulated_precision(alpha0, n_groups, n_sensors, n_vertices, rng):
    """Return inv(I + sum_g G_g G_g^T / alpha0) for random Gaussian G_g.

    Each G_g is an (n_sensors, n_vertices) standard-normal draw taken from
    ``rng`` (either the ``np.random`` module or a ``RandomState``).
    """
    gs = [rng.randn(n_sensors, n_vertices) for _ in range(n_groups)]
    if gs:
        # sum_g G_g G_g^T == G G^T with G = [G_1 | G_2 | ...]; one matmul
        # instead of stacking n_groups full (n_sensors, n_sensors) matrices.
        stacked = np.hstack(gs)
        gag = stacked.dot(stacked.T) / alpha0
    else:
        gag = np.zeros((n_sensors, n_sensors))
    return np.linalg.inv(np.eye(n_sensors) + gag)


def check_logdet(alpha0=1e4, n_groups=100, n_sensors=200, n_vertices=100,
                 n_timepoints=2, seed=None):
    """Print the logp of random data under MvNormal and MyMvNormal.

    Both models observe the same data ``bs`` (n_timepoints x n_sensors) with
    zero mean and the same precision matrix
    ``inv(I + sum_g G_g G_g^T / alpha0)``. The only difference is that
    pymc3's ``MvNormal`` is used for the first line and ``MyMvNormal``
    (log-determinant via ``logabsdet``) for the second.

    Parameters
    ----------
    alpha0 : float
        Scale dividing the sum of G_g G_g^T; smaller values inflate the
        covariance.
    n_groups, n_sensors, n_vertices, n_timepoints : int
        Problem sizes; defaults match the original hard-coded values.
    seed : int or None
        Seed for a private ``RandomState``. ``None`` draws from numpy's
        global RNG, as the original code did.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    # Draw the observations first so the global-RNG stream is consumed in
    # the same order as before (bs, then the G_g matrices).
    bs = rng.randn(n_timepoints, n_sensors)
    prec = _simulated_precision(alpha0, n_groups, n_sensors, n_vertices, rng)

    with Model() as model1:
        # Likelihood for observations
        MvNormal('l', mu=0.0, tau=T.as_tensor(prec), observed=bs)

    with Model() as model2:
        # Likelihood for observations
        MyMvNormal('l', mu=0.0, tau=T.as_tensor(prec), observed=bs)

    print('logp of with log(det()) = {}'.format(model1.logp()))
    print('logp of with logabsdet() = {}'.format(model2.logp()))
from theano.gof import Op, Apply | |
# The code is adapted from https://github.com/Theano/Theano/pull/3959
class LogAbsDet(Op):
    """Compute log(abs(det(M))) of a square matrix M on CPU.

    The value comes from ``numpy.linalg.slogdet`` rather than from forming
    ``det(M)`` explicitly, so very large or very small determinants do not
    overflow/underflow. A singular M yields ``-inf``.

    TODO: add GPU code!
    """

    # The Op has no parameters, so all instances are equivalent; declaring
    # empty __props__ lets Theano compare/hash them and merge duplicate nodes.
    __props__ = ()

    def make_node(self, x):
        """Build an Apply node mapping a 2-D tensor to a scalar of its dtype.

        Raises
        ------
        ValueError
            If ``x`` is not a 2-D tensor.
        """
        x = theano.tensor.as_tensor_variable(x)
        if x.ndim != 2:
            raise ValueError(
                'LogAbsDet expects a 2-D matrix, got ndim={}'.format(x.ndim))
        o = theano.tensor.scalar(dtype=x.dtype)
        return Apply(self, [x], [o])

    def perform(self, node, inputs, outputs):
        # Unpack outside the try so the error handler can always refer to x.
        (x,) = inputs
        (z,) = outputs
        try:
            _, log_abs_det = np.linalg.slogdet(x)
        except Exception:
            print('Failed to compute logabsdet of {}.'.format(x))
            raise
        z[0] = np.asarray(log_abs_det, dtype=x.dtype)

    def grad(self, inputs, g_outputs):
        # d/dX log|det(X)| = inv(X)^T
        (gz,) = g_outputs
        (x,) = inputs
        return [gz * T.nlinalg.matrix_inverse(x).T]

    def __str__(self):
        return "LogAbsDet"


logabsdet = LogAbsDet()
from scipy import stats | |
from pymc3.distributions.distribution import Continuous, draw_values, generate_samples | |
class MyMvNormal(Continuous):
    r"""
    Multivariate normal log-likelihood, parameterized by a precision matrix.

    Identical in form to pymc3's ``MvNormal`` with ``tau``, except that the
    log-determinant term uses ``logabsdet(tau)`` instead of
    ``log(det(tau))``.

    .. math::

        f(x \mid \mu, T) =
            \frac{|T|^{1/2}}{(2\pi)^{k/2}}
            \exp\left\{ -\frac{1}{2} (x-\mu)^{\prime} T (x-\mu) \right\}

    ========  ==========================
    Support   :math:`x \in \mathbb{R}^k`
    Mean      :math:`\mu`
    Variance  :math:`T^{-1}`
    ========  ==========================

    Parameters
    ----------
    mu : array
        Vector of means.
    tau : array
        Precision matrix.
    """

    def __init__(self, mu, tau, *args, **kwargs):
        super(MyMvNormal, self).__init__(*args, **kwargs)
        self.mean = self.median = self.mode = self.mu = mu
        self.tau = tau

    def random(self, point=None, size=None):
        """Draw samples using scipy's multivariate normal.

        ``tau`` is a precision matrix, but ``scipy.stats.multivariate_normal``
        is parameterized by the covariance, so ``tau`` is inverted first.
        """
        mu, tau = draw_values([self.mu, self.tau], point=point)
        cov = np.linalg.inv(tau)

        def _random(mean, cov, size=None):
            return stats.multivariate_normal.rvs(
                mean, cov, None if size == mean.shape else size)

        samples = generate_samples(_random,
                                   mean=mu, cov=cov,
                                   dist_shape=self.shape,
                                   broadcast_shape=mu.shape,
                                   size=size)
        return samples

    def logp(self, value):
        """Log-density of ``value``, summed over the last axis of the delta.

        Computes -1/2 * [k*log(2*pi) - log|det(tau)| + delta' tau delta]
        with delta = value - mu.
        """
        mu = self.mu
        tau = self.tau
        delta = value - mu
        k = tau.shape[0]
        # logabsdet avoids the -inf that log(det(tau)) gives when det(tau)
        # underflows to 0 for large, well-conditioned-but-small-det precisions.
        result = k * T.log(2 * np.pi) - logabsdet(tau)
        result += (delta.dot(tau) * delta).sum(axis=delta.ndim - 1)
        return -1/2. * result
#In [1]: import sys; sys.path.insert(0, '/Users/taku-y/git/github/pymc3') | |
# | |
#In [2]: import test | |
# | |
#In [3]: test.run_check_logdet() | |
#alpha0=1e4 | |
#logp of with log(det()) = -613.857479372 | |
#logp of with logabsdet() = -613.857479372 | |
#alpha0=1e3 | |
#logp of with log(det()) = -862.731539233 | |
#logp of with logabsdet() = -862.731539233 | |
#alpha0=1e2 | |
#logp of with log(det()) = -inf | |
#logp of with logabsdet() = -1290.73899232 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment