Skip to content

Instantly share code, notes, and snippets.

@marmakoide
Created May 11, 2018 14:01
Show Gist options
  • Save marmakoide/6f55ff99f14c896399c460a38f72c99a to your computer and use it in GitHub Desktop.
Save marmakoide/6f55ff99f14c896399c460a38f72c99a to your computer and use it in GitHub Desktop.
Von Mises-Fisher distribution in 3D, i.e. a spherical analog of an isotropic Gaussian distribution.
import numpy
def circle_uniform_pick(size, out = None):
    '''
    Pick points uniformly from the unit circle.

    size => number of points to draw
    out  => optional (size, 2) array filled in place; a new array is
            allocated when it is omitted

    Returns the array holding one (cos, sin) pair per row.
    '''
    points = numpy.empty((size, 2)) if out is None else out
    # A uniform angle in [0, 2 pi) gives a uniform point on the circle
    theta = (2 * numpy.pi) * numpy.random.random(size)
    points[:, 0] = numpy.cos(theta)
    points[:, 1] = numpy.sin(theta)
    return points
def cross_product_matrix(U):
    '''
    Return the 3x3 skew-symmetric matrix M built from the 3d vector U, such
    that M . V equals the cross product U x V.
    '''
    ux, uy, uz = U[0], U[1], U[2]
    rows = (
        ( 0., -uz,  uy),
        ( uz,  0., -ux),
        (-uy,  ux,  0.),
    )
    return numpy.array(rows)
class SingularityError(Exception):
    '''
    Raised when an estimation hits a numerically singular configuration.

    No __init__ override: the exception can be raised bare, as
    SingularityError(), or with an explanatory message.
    '''
class VonMisesFisher3(object):
    '''
    Von Mises-Fisher distribution, ie. isotropic Gaussian distribution defined
    over the unit sphere in 3d.

    mu => mean direction (3d vector, expected to have unit length)
    kappa => concentration (kappa = 0 is the uniform distribution)

    Sampling uses numerical tricks described in "Numerically stable sampling
    of the von Mises Fisher distribution on S2 (and other tricks)" by Wenzel
    Jakob

    Estimation uses the maximum likelihood estimators described in "Modeling
    Data using Directional Distributions" by Inderjit S. Dhillon and Suvrit
    Sra
    '''

    def __init__(self, mu, kappa):
        # Stored as a float array so that pdf() and log_pdf() can rely on
        # self.mu.dot() even when mu is passed as a list or a tuple
        self.mu = numpy.asarray(mu, dtype = float)
        self.kappa = kappa
        if self.kappa == 0.:
            # Limit of the normalization constant for kappa -> 0, the general
            # formula would evaluate 0 / 0
            self.pdf_constant = .25 / numpy.pi
        else:
            # -expm1(-2 kappa) is 1 - exp(-2 kappa) without cancellation for
            # small kappa
            self.pdf_constant = self.kappa / (-2. * numpy.pi * numpy.expm1(-2. * self.kappa))
        self.log_pdf_constant = numpy.log(self.pdf_constant)

    @staticmethod
    def _rotation_from_z(mu):
        '''
        Returns a 3x3 rotation matrix R such that R . (0, 0, 1) = mu / |mu|
        '''
        x, y, z = mu / numpy.sqrt(numpy.sum(mu ** 2))
        # For directions in the lower hemisphere, build the rotation toward
        # the mirrored direction (x, -y, -z) and compose it with a half turn
        # around the x axis. This keeps 1 + z >= 1, so the closed form below
        # never divides by a value close to 0, including for mu = (0, 0, -1)
        flip = z < 0.
        if flip:
            y, z = -y, -z
        c = 1. / (1. + z)
        rot = numpy.array([[1. - c * x * x, -c * x * y, x],
                           [-c * x * y, 1. - c * y * y, y],
                           [-x, -y, z]])
        if flip:
            # Left-multiply by diag(1, -1, -1)
            rot[1:] *= -1.
        return rot

    def sample(self, size, out = None):
        '''
        Generates samples from the distribution

        size => number of samples
        out  => optional (size, 3) array, filled in place

        Returns a (size, 3) array with one unit vector per row (out itself
        when it is provided).
        '''
        # Generate the samples for mu=(0, 0, 1)
        u = numpy.random.random(size)
        if self.kappa == 0.:
            # Uniform distribution over the sphere
            W = 1. - 2. * u
        else:
            # Inverse CDF of the cosine to the mean direction, ie.
            # W = 1 + log(eta + (1 - eta) exp(-2 kappa)) / kappa with
            # eta = 1 - u. The log1p / expm1 form stays finite for every u in
            # [0, 1) and every kappa > 0
            W = 1. + numpy.log1p(u * numpy.expm1(-2. * self.kappa)) / self.kappa
        # Guard against rounding pushing W outside of [-1, 1]
        W = numpy.clip(W, -1., 1.)
        angle = 2 * numpy.pi * numpy.random.random(size)
        radius = numpy.sqrt(1. - W ** 2)
        if out is None:
            out = numpy.empty((size, 3))
        out[:, 0] = radius * numpy.cos(angle)
        out[:, 1] = radius * numpy.sin(angle)
        out[:, 2] = W
        # Rotate the samples to the distribution's mu. Samples are stored as
        # rows, hence the transposed matrix. Written in place so that a
        # caller-provided 'out' receives the rotated samples
        rot = VonMisesFisher3._rotation_from_z(self.mu)
        out[...] = numpy.dot(out, rot.T)
        # Job done
        return out

    def pdf(self, X):
        '''
        Returns the probability for X to be generated by the distribution

        X is combined with mu as mu . X, so X is a 3d vector or a (3, n)
        array with one vector per column (transpose the output of sample()).
        '''
        if self.kappa == 0.:
            return .25 / numpy.pi
        else:
            return self.pdf_constant * numpy.exp(self.kappa * (self.mu.dot(X) - 1.))

    def log_pdf(self, X):
        '''
        Returns the log-probability for X to be generated by the distribution

        X follows the same layout as for pdf().
        '''
        if self.kappa == 0.:
            return numpy.log(.25 / numpy.pi)
        else:
            return self.log_pdf_constant + self.kappa * (self.mu.dot(X) - 1.)

    def __repr__(self):
        return 'VonMisesFisher3(mu = %s, kappa = %f)' % (repr(self.mu), self.kappa)

    @staticmethod
    def _get_kappa(R_bar):
        '''
        Returns the approximation R (3 - R^2) / (1 - R^2) of the concentration
        for a mean resultant length R_bar.

        Raises SingularityError when 1 - R_bar^2 is numerically 0.
        '''
        f = 1. - R_bar ** 2
        if numpy.allclose(f, 0.):
            raise SingularityError()
        return (R_bar * (3. - R_bar ** 2)) / f

    @staticmethod
    def estimate(samples):
        '''
        Returns an approximation of the most likely VMF to generate samples
        - Assumes that the samples lies on the unit sphere
        - samples holds one 3d vector per row
        '''
        X = numpy.asarray(samples)
        # Estimate for mu
        S = numpy.sum(X, axis = 0)
        norm_S = numpy.sqrt(numpy.sum(S ** 2))
        mu = S / norm_S
        # Initial estimate for kappa
        R_bar = norm_S / X.shape[0]
        kappa = VonMisesFisher3._get_kappa(R_bar)
        # Refine kappa estimate
        # TODO
        # Job done
        return VonMisesFisher3(mu, kappa)

    @staticmethod
    def estimate_weighted(samples, log_weights):
        '''
        Returns an approximation of the most likely VMF to generate samples.
        A weight vector specify the relative significance of each sample.
        - Assumes that the samples lies on the unit sphere
        - Assumes that the weight vector sum equals 1.
        - Despite the parameter name, log_weights is used as plain (linear)
          weights, no exponential is applied
        '''
        X = numpy.asarray(samples)
        weights = numpy.asarray(log_weights)
        # Estimate for mu
        S = numpy.sum(X * weights[:, None], axis = 0)
        norm_S = numpy.sqrt(numpy.sum(S ** 2))
        mu = S / norm_S
        # Initial estimate for kappa, the weights already sum to 1
        R_bar = norm_S
        kappa = VonMisesFisher3._get_kappa(R_bar)
        # Refine kappa estimate
        # TODO
        # Job done
        return VonMisesFisher3(mu, kappa)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment