@brandonwillard
Created May 18, 2016 20:36
Somewhat fixed MvNormal implementation
import numpy as np

from functools import partial

from scipy import stats

import theano.tensor as T

from theano import scan
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh

from pymc3 import transforms
from pymc3.distributions.distribution import Continuous, Discrete, draw_values, generate_samples
from pymc3.distributions.special import gammaln, multigammaln
from pymc3.distributions.dist_math import bound, logpow, factln


class MvNormal(Continuous):
    r"""
    This MvNormal can handle tensor arguments (to some degree).  Also, the
    sampling routine uses the correct covariance, but it's costly.
    """

    def __init__(self, mu, tau, *args, **kwargs):
        super(MvNormal, self).__init__(*args, **kwargs)
        self.mean = self.median = self.mode = self.mu = mu
        self.tau = tau

    def random(self, point=None, size=None):
        mu, tau = draw_values([self.mu, self.tau], point=point)

        def _random(mean, tau, size=None):
            supp_dim = mean.shape[-1]
            mus_collapsed = mean.reshape((-1, supp_dim))
            taus_collapsed = tau.reshape((-1, supp_dim, supp_dim))
            # FIXME: do something smarter about tau/cov.
            # `np.linalg.inv` inverts each precision matrix in the stack.
            covs_collapsed = np.linalg.inv(taus_collapsed)

            mvrvs = partial(stats.multivariate_normal.rvs, size=1)
            res = list(map(mvrvs, mus_collapsed, covs_collapsed))
            # FIXME: this is a hack; the PyMC sampling framework
            # will incorrectly set `size == Distribution.shape` when a single
            # sample is requested, implying that we want
            # `Distribution.shape`-many samples of a
            # `Distribution.shape` sized object: too many!  That's why
            # we're ignoring `size` right now and only ever asking
            # for a single sample.
            return np.asarray(res).reshape(mean.shape)

        samples = generate_samples(_random,
                                   mean=mu, tau=tau,
                                   dist_shape=self.shape,
                                   broadcast_shape=mu.shape,
                                   size=size)
        return samples

    def logp(self, value):
        mu = T.as_tensor_variable(self.mu)
        tau = T.as_tensor_variable(self.tau)

        reps_shape_T = tau.shape[:-2]
        reps_shape_prod = T.prod(reps_shape_T, keepdims=True)
        dist_shape_T = mu.shape[-1:]

        # collapse reps dimensions
        flat_supp_shape = T.concatenate((reps_shape_prod, dist_shape_T))
        mus_collapsed = mu.reshape(flat_supp_shape, ndim=2)
        taus_collapsed = tau.reshape(T.concatenate((reps_shape_prod,
                                                    dist_shape_T,
                                                    dist_shape_T)),
                                     ndim=3)

        # force value to conform to reps_shape
        value_reshape = T.ones_like(mu) * value
        values_collapsed = value_reshape.reshape(flat_supp_shape, ndim=2)

        def single_logl(_mu, _tau, _value, k):
            delta = _value - _mu
            # log-density of a single k-dimensional normal with precision
            # matrix `_tau`: -(k*log(2*pi) - log|tau| + delta' tau delta) / 2
            result = k * T.log(2 * np.pi) - T.log(det(_tau))
            result += (delta.dot(_tau) * delta).sum(axis=-1)
            return -result / 2

        res, _ = scan(fn=single_logl,
                      sequences=[mus_collapsed, taus_collapsed,
                                 values_collapsed],
                      non_sequences=[dist_shape_T],
                      strict=True)
        return res.sum()
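A usage sketch, not part of the original gist: it assumes the code above is saved as mvnormal_extension.py (the module name used in the comments below) and illustrates the batched mu/tau case the docstring refers to, with two independent 3-dimensional normals stacked along the first axis.

import numpy as np
import pymc3 as pm
from mvnormal_extension import MvNormal

with pm.Model() as model:
    # batch of two 3-dimensional normals: mu has shape (2, 3) and
    # tau has shape (2, 3, 3), one precision matrix per batch entry
    mu = np.stack([np.zeros(3), 3 * np.ones(3)])
    tau = np.stack([np.eye(3), 2 * np.eye(3)])
    x = MvNormal('x', mu=mu, tau=tau, shape=(2, 3))
    trace = pm.sample(500)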
@AngelBerihuete commented Jul 20, 2016

Hi @brandonwillard
The problem above was fixed, but I have a question about your function. Coding a simple model like

import numpy as np
import pymc3 as pm
from mvnormal_extension  import MvNormal

with pm.Model() as model:
    var_x = MvNormal('var_x', mu = 3*np.zeros(2), tau = np.diag(np.ones(2)),  shape=2)
    trace = pm.sample(100)

pm.summary(trace)

I obtain these results

var_x:

  Mean             SD               MC Error         95% HPD interval
  -------------------------------------------------------------------

  0.220            1.161            0.116            [-1.897, 2.245]
  0.165            1.024            0.102            [-2.626, 1.948]

  Posterior quantiles:
  2.5            25             50             75             97.5
  |--------------|==============|==============|--------------|

  -1.897         -0.761         0.486          1.112          2.245
  -2.295         -0.426         0.178          0.681          2.634

which, of course, is not correct because the mean should be near (3, 3).

Any clue why this is happening?
Thanks
Angel

@AngelBerihuete

Forget the last message.

Of course the means are near (0, 0) if MvNormal has

mu = 3*np.zeros(2)

It was my fault; I was thinking of

mu = 3*np.ones(2)
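For reference, here is the snippet from the first comment with that correction applied (same model, only mu changed):

import numpy as np
import pymc3 as pm
from mvnormal_extension import MvNormal

with pm.Model() as model:
    var_x = MvNormal('var_x', mu=3*np.ones(2), tau=np.diag(np.ones(2)), shape=2)
    trace = pm.sample(100)

pm.summary(trace)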

@lezorich

Hi @AngelBerihuete, how did you fix the DensityDist alternative? I'm getting the same exception as you.

Thanks!
Lukas.

@AngelBerihuete

Hi @lezorich, I'm so sorry for the huge delay.
In the end I used the normal-extension.py file coded by @brandonwillard.
Did you fix the DensityDist alternative?
Angel
