Created
May 11, 2018 14:01
-
-
Save marmakoide/6f55ff99f14c896399c460a38f72c99a to your computer and use it in GitHub Desktop.
Von Mises-Fisher distribution in 3D, i.e. a spherical analog of an isotropic Gaussian distribution.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
def circle_uniform_pick(size, out = None):
    """Draw `size` points uniformly distributed on the unit circle.

    Parameters:
        size: number of points to draw.
        out: optional array of shape (size, 2) that receives the points;
             a new array is allocated when it is None.
    Returns:
        The (size, 2) array of points, one (cos, sin) pair per row. This is
        `out` itself when it was supplied.
    """
    points = numpy.empty((size, 2)) if out is None else out
    # A uniform angle in [0, 2 pi) gives a uniform point on the circle.
    theta = 2 * numpy.pi * numpy.random.random(size)
    points[:, 0] = numpy.cos(theta)
    points[:, 1] = numpy.sin(theta)
    return points
def cross_product_matrix(U):
    """Return the 3x3 skew-symmetric matrix [U]_x of the 3-vector U.

    Multiplying this matrix by a vector v gives the cross product U x v.
    """
    ux, uy, uz = U[0], U[1], U[2]
    rows = [
        [0., -uz, uy],
        [uz, 0., -ux],
        [-uy, ux, 0.],
    ]
    return numpy.array(rows)
class SingularityError(Exception):
    """Raised when a computation hits a numerical singularity.

    VonMisesFisher3._get_kappa raises it when 1 - R_bar ** 2 is numerically
    zero, because its kappa approximation would divide by zero there.
    """

    def __init__(self, message = 'numerical singularity encountered'):
        # Forward the message so str(exc) and exc.args are informative; the
        # original no-op __init__ left the exception with an empty message.
        super(SingularityError, self).__init__(message)
class VonMisesFisher3(object):
    """Von Mises-Fisher (VMF) distribution on the unit sphere S^2.

    The VMF distribution is the spherical analog of an isotropic Gaussian
    distribution: its density at a unit vector x is proportional to
    exp(kappa * mu . x).

    mu    => mean direction. It is expected to be a unit 3-vector; it is
             stored as a float numpy array, but its norm is not checked.
    kappa => concentration. 0 gives the uniform distribution on the sphere.
             It is assumed to be >= 0; negative values are not validated.

    Uses numerical tricks described in "Numerically stable sampling of the von
    Mises Fisher distribution on S2 (and other tricks)" by Wenzel Jakob.
    Uses maximum likelihood estimators described in "Modeling Data using
    Directional Distributions" by Inderjit S. Dhillon and Suvrit Sra.
    """

    def __init__(self, mu, kappa):
        # asarray lets callers pass a list; pdf()/log_pdf() need .dot().
        self.mu = numpy.asarray(mu, dtype = float)
        self.kappa = kappa
        if self.kappa == 0.:
            # Limit of the general expression as kappa -> 0 (the uniform
            # density). Evaluating the general form here would be 0 / 0.
            self.pdf_constant = .25 / numpy.pi
        else:
            # kappa / (4 pi sinh(kappa)) multiplied by exp(kappa), so that
            # pdf() can use exp(kappa * (mu . x - 1)), which is <= 1 on the
            # sphere. -expm1(-2 kappa) == 1 - exp(-2 kappa), but keeps full
            # precision when kappa is small.
            self.pdf_constant = self.kappa / (2 * numpy.pi * -numpy.expm1(-2. * self.kappa))
        self.log_pdf_constant = numpy.log(self.pdf_constant)

    @staticmethod
    def _orthonormal_basis(n):
        """Return unit vectors (b1, b2) such that (b1, b2, n) is orthonormal.

        Assumes n is a unit 3-vector. Uses the branchless construction of
        Duff et al., "Building an Orthonormal Basis, Revisited" (JCGT 2017).
        It is well defined for every unit n, including n = (0, 0, -1), and
        gives b1 = (1, 0, 0), b2 = (0, 1, 0) for n = (0, 0, 1).
        """
        x, y, z = n[0], n[1], n[2]
        sign = numpy.copysign(1., z)
        a = -1. / (sign + z)
        b = x * y * a
        b1 = numpy.array((1. + sign * x * x * a, sign * b, -sign * x))
        b2 = numpy.array((b, sign + y * y * a, -y))
        return b1, b2

    def sample(self, size, out = None):
        """Generate samples from the distribution.

        Parameters:
            size: number of samples to draw.
            out: optional array of shape (size, 3) that receives the samples.
        Returns:
            A (size, 3) array with one unit-vector sample per row. When `out`
            is supplied, it is filled in place and returned.
        """
        # 1) Draw W = mu . x by inverting its CDF. With u uniform on [0, 1):
        #      W = 1 + log(1 + u * (exp(-2 kappa) - 1)) / kappa
        #    Keep delta = W - 1 separately, because
        #    1 - W ** 2 == -delta * (2 + delta) stays accurate when W ~ 1.
        #    u < 1, so the log1p argument is > -1 and never gives -inf.
        u = numpy.random.random(size)
        if self.kappa == 0.:
            # kappa -> 0 limit: W is uniform on (-1, 1].
            delta = -2. * u
        else:
            delta = numpy.log1p(u * numpy.expm1(-2. * self.kappa)) / self.kappa
        W = 1. + delta
        # Clamp tiny negative round-off so sqrt never produces NaN.
        radius = numpy.sqrt(numpy.maximum(-delta * (2. + delta), 0.))

        # 2) Uniform direction in the plane orthogonal to mu.
        V = numpy.empty((size, 2))
        circle_uniform_pick(size, out = V)
        V *= radius[:, None]

        # 3) Express the samples in an orthonormal frame (b1, b2, mu). The
        #    distribution is invariant under rotations about mu, so any
        #    orthonormal completion of mu gives the same distribution. This
        #    replaces the old axis-angle rotation: its axis had the wrong sign,
        #    and it divided 0 / 0 when mu = (0, 0, -1).
        if out is None:
            out = numpy.empty((size, 3))
        b1, b2 = VonMisesFisher3._orthonormal_basis(self.mu)
        # Assign in place so a caller-provided `out` really receives the result.
        out[...] = V[:, 0, None] * b1 + V[:, 1, None] * b2 + W[:, None] * self.mu
        return out

    def pdf(self, X):
        """Return the probability density of the distribution at X.

        X is a unit 3-vector, or a (3, N) array whose columns are unit
        vectors (mu.dot(X) contracts over X's first axis). sample() returns
        (size, 3), so pass its transpose. kappa == 0 uses the same formula
        and therefore returns the uniform density with the shape of mu.dot(X).
        """
        return self.pdf_constant * numpy.exp(self.kappa * (self.mu.dot(X) - 1.))

    def log_pdf(self, X):
        """Return the log probability density of the distribution at X.

        X follows the same layout as in pdf(): a 3-vector or a (3, N) array
        of unit column vectors.
        """
        return self.log_pdf_constant + self.kappa * (self.mu.dot(X) - 1.)

    def __repr__(self):
        return 'VonMisesFisher3(mu = %s, kappa = %f)' % (repr(self.mu), self.kappa)

    @staticmethod
    def _get_kappa(R_bar):
        """Approximate the maximum-likelihood kappa from the mean resultant length.

        Uses the closed form R_bar * (3 - R_bar ** 2) / (1 - R_bar ** 2) for
        dimension 3. Raises SingularityError when 1 - R_bar ** 2 is
        numerically zero (e.g. all samples identical).
        """
        f = 1. - R_bar ** 2
        if numpy.allclose(f, 0.):
            raise SingularityError()
        return (R_bar * (3. - R_bar ** 2)) / f

    @staticmethod
    def _resultant_direction(S, norm_S):
        """Return S / norm_S, or (0, 0, 1) when the resultant vanishes.

        A zero resultant gives R_bar = 0 and hence kappa = 0 (uniform), for
        which mu is irrelevant. The fallback avoids the 0 / 0 that would
        otherwise make mu NaN.
        """
        if norm_S == 0.:
            return numpy.array((0., 0., 1.))
        return S / norm_S

    @staticmethod
    def estimate(samples):
        """Return an approximation of the VMF most likely to generate samples.

        samples: array-like of shape (N, 3), one sample per row. The samples
        are assumed to lie on the unit sphere.
        Raises SingularityError when the mean resultant length is numerically
        1 (e.g. all samples identical).
        """
        X = numpy.asarray(samples, dtype = float)
        # Estimate for mu: direction of the resultant vector
        S = numpy.sum(X, axis = 0)
        norm_S = numpy.sqrt(numpy.sum(S ** 2))
        mu = VonMisesFisher3._resultant_direction(S, norm_S)
        # Initial estimate for kappa from the mean resultant length
        R_bar = norm_S / X.shape[0]
        kappa = VonMisesFisher3._get_kappa(R_bar)
        # Refine kappa estimate
        # TODO
        return VonMisesFisher3(mu, kappa)

    @staticmethod
    def estimate_weighted(samples, log_weights):
        """Return an approximation of the VMF most likely to generate samples.

        A weight vector specifies the relative significance of each sample.
        - samples: array-like of shape (N, 3); assumed to lie on the unit sphere.
        - log_weights: array-like of N weights. NOTE(review): despite the
          name, the values are used as plain (linear) weights, not
          log-weights; confirm with callers.
        - The weights no longer need to sum to 1: R_bar is divided by their
          total, which changes nothing when they already sum to 1. The total
          is assumed to be positive.
        Raises SingularityError when the weighted mean resultant length is
        numerically 1.
        """
        X = numpy.asarray(samples, dtype = float)
        weights = numpy.asarray(log_weights, dtype = float)
        # Estimate for mu: direction of the weighted resultant vector
        S = numpy.sum(X * weights[:, None], axis = 0)
        norm_S = numpy.sqrt(numpy.sum(S ** 2))
        mu = VonMisesFisher3._resultant_direction(S, norm_S)
        # Initial estimate for kappa from the weighted mean resultant length
        R_bar = norm_S / numpy.sum(weights)
        kappa = VonMisesFisher3._get_kappa(R_bar)
        # Refine kappa estimate
        # TODO
        return VonMisesFisher3(mu, kappa)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment